From 303d4594f4c086e15f2cf5fc7fcb00cae6a49c15 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Tue, 2 Apr 2024 11:02:52 -0500 Subject: [PATCH 1/7] improved zeroing (#1221) --- example/15_grouped_gemm/CMakeLists.txt | 4 +- .../grouped_gemm_xdl_fixed_nk_fp16.cpp | 10 +- ...=> grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp} | 4 +- .../impl/device_grouped_gemm_xdl_fixed_nk.hpp | 215 ++++++++---- ...se_gemm_multiple_d_xdl_splitk_cshuffle.hpp | 325 +++++++++++------- 5 files changed, 367 insertions(+), 191 deletions(-) rename example/15_grouped_gemm/{grouped_gemm_xdl_fixed_nk_fp8.cpp => grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp} (99%) diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index 84040fcf5c..550dafb066 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -23,8 +23,8 @@ add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16) add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8) -add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp) -add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8) +add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16_fp8 grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16_fp8) if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp index 2c1feafce3..1a2bcfb33e 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp @@ -36,7 +36,7 @@ using BDataType = F16; using AccDataType = F32; using CShuffleDataType = F32; 
using DsDataType = ck::Tuple<>; -using EDataType = F32; +using EDataType = F16; using ALayout = Row; using BLayout = Col; @@ -55,7 +55,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on struct ProblemSize final @@ -298,9 +298,9 @@ int main(int argc, char* argv[]) for(int i = 0; i < problem_size.group_count; i++) { - problem_size.Ms.push_back(256 + 256 * i); - problem_size.Ns.push_back(256); - problem_size.Ks.push_back(128); + 
problem_size.Ms.push_back(128 + rand() % 128); + problem_size.Ns.push_back(1024); + problem_size.Ks.push_back(1024); problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]); diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp similarity index 99% rename from example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp rename to example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp index 9fd63cba77..0a63a29843 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp @@ -35,7 +35,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ADataType = F16; using BDataType = F8; using AccDataType = F32; -using CShuffleDataType = F32; +using CShuffleDataType = F16; using DsDataType = ck::Tuple<>; using EDataType = F16; @@ -56,7 +56,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, 
AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; // clang-format on struct ProblemSize final diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp index d197c56ab8..c98ec6e2aa 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp @@ -23,6 +23,7 @@ namespace device { template (gemm_desc_ptr[group_id].p_a_grid, + gemm_desc_ptr[group_id].p_b_grid, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + barrier_count_finished, + a_element_op, + b_element_op, + c_element_op, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + block_2_etile_map); + } + else + { - GridwiseGemm::template Run(gemm_desc_ptr[group_id].p_a_grid, - gemm_desc_ptr[group_id].p_b_grid, - p_ds_grid_, - gemm_desc_ptr[group_id].p_e_grid, - p_shared, - barrier_count_finished, - a_element_op, - b_element_op, - c_element_op, - M, - N, - K, - StrideA, - StrideB, - StrideDs, - StrideE, - KBatch, - block_2_etile_map); + GridwiseGemm::template Run(gemm_desc_ptr[group_id].p_a_grid, + gemm_desc_ptr[group_id].p_b_grid, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + nullptr, + a_element_op, + b_element_op, + c_element_op, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + 
block_2_etile_map); + } id_off += grid_size_grp; id_local += grid_size_grp; @@ -193,8 +224,11 @@ template + PipelineVersion PipelineVer = PipelineVersion::v1, + LoopScheduler LoopSched = make_default_loop_scheduler(), + typename ComputeType = ADataType, + typename ALDSType = ComputeType, + typename BLDSType = ComputeType> struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK{}; static constexpr auto I2 = Number<2>{}; + using AComputeType = ComputeType; + using BComputeType = ComputeType; + // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle< ADataType, // TODO: distinguish A/B datatype BDataType, - ComputeType, + AComputeType, + BComputeType, AccDataType, CShuffleDataType, DsDataType, @@ -258,7 +296,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK; + LoopSched, + PipelineVer, + ALDSType, + BLDSType>; template struct OffsettedBlockToCTileMapMLoops @@ -613,45 +654,85 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK, - GemmSpec, - ALayout, - BLayout, - DsLayout, - ELayout, - DsDataType, - Block2ETileMap, - GroupedGemmBlock2ETileMap, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - e_global_memory_operation_, - has_main_k_block_loop_>; + if(arg.k_batch_ == 1) + { + const auto kernel = + kernel_grouped_gemm_xdl_fixed_nk, + GemmSpec, + false, + ALayout, + BLayout, + DsLayout, + ELayout, + DsDataType, + Block2ETileMap, + GroupedGemmBlock2ETileMap, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + e_global_memory_operation_, + has_main_k_block_loop_>; - return launch_and_time_kernel( - stream_config, - kernel, - dim3(arg.grid_size_), - dim3(BlockSize), - 0, - cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), - reinterpret_cast(arg.p_workspace_), - arg.barrier_size_grp_, - arg.gemm_desc_kernel_arg_.size(), - arg.grid_size_grp_, - arg.k_batch_, - arg.a_element_op_, - arg.b_element_op_, - 
arg.c_element_op_); + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), + nullptr, + arg.barrier_size_grp_, + arg.gemm_desc_kernel_arg_.size(), + arg.grid_size_grp_, + arg.k_batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + } + else + { + const auto kernel = + kernel_grouped_gemm_xdl_fixed_nk, + GemmSpec, + true, + ALayout, + BLayout, + DsLayout, + ELayout, + DsDataType, + Block2ETileMap, + GroupedGemmBlock2ETileMap, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + e_global_memory_operation_, + has_main_k_block_loop_>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), + reinterpret_cast(arg.p_workspace_), + arg.barrier_size_grp_, + arg.gemm_desc_kernel_arg_.size(), + arg.grid_size_grp_, + arg.k_batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + } }; constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd; constexpr auto Set = InMemoryDataOperationEnum::Set; - // For bf16 datatype only kbatch = 1 scenario is supported. This condition is enforced - // in IsSupportedArgument function + // For bf16 datatype only kbatch = 1 scenario is supported. 
This condition is + // enforced in IsSupportedArgument function if constexpr(std::is_same::value) { if(has_main_k_block_loop) @@ -719,12 +800,12 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK + PipelineVersion PipelineVer, + typename ALDSType, + typename BLDSType> struct GridwiseGemmMultipleD_xdl_splitk_cshuffle { static constexpr index_t NumDTensor = DsDataType::Size(); @@ -186,8 +189,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(ComputeType), + return math::max(a_block_space_size_aligned * sizeof(ALDSType) + + b_block_space_size_aligned * sizeof(BLDSType), c_block_size * sizeof(CShuffleDataType)); } @@ -455,6 +458,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle InMemoryDataOperationEnum EGlobalMemoryDataOperation, index_t NumDTensor_, typename DsDataType_, + bool Zeroing, typename AGridDesc_KBatch_AK0_M_AK1, typename BGridDesc_KBatch_BK0_N_BK1, typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, @@ -530,7 +534,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ADataType, - ComputeType, + ALDSType, decltype(a_grid_desc_kbatch_ak0_m_ak1), decltype(a_block_desc_kbatch_ak0_m_ak1), ABlockTransferSrcAccessOrder, @@ -561,7 +565,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BDataType, - ComputeType, + BLDSType, decltype(b_grid_desc_kbatch_bk0_n_bk1), decltype(b_block_desc_kbatch_bk0_n_bk1), BBlockTransferSrcAccessOrder, @@ -597,12 +601,12 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle // sanity check constexpr index_t KPack = math::max(math::lcm(AK1, BK1), - MfmaSelector::selected_mfma.k_per_blk); + 
MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - ComputeType, - ComputeType, + ALDSType, + BLDSType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -611,62 +615,65 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle MXdlPerWave, NXdlPerWave, KPack, - LoopSched>(); + LoopSched, + AComputeType, + BComputeType>(); -#if 1 - if(block_work_idx[I0] == 0) + if constexpr(Zeroing) { - const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock; - const index_t numNThreads = NPerBlock / nThreadSize; - const index_t numMThreads = BlockSize / numNThreads; - const index_t mThreadSize = MPerBlock / numMThreads; - - const index_t m_tid = get_thread_local_1d_id() / numNThreads; - const index_t n_tid = get_thread_local_1d_id() % numNThreads; - - auto c_thread_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, Number{}, I1, Number{})); - - StaticBuffer - e_thread_zero_buf; - - auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3< - EDataType, - EDataType, - decltype(c_thread_desc_mblock_mperblock_nblock_nperblock), - decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), - ck::tensor_operation::element_wise::PassThrough, - Sequence<1, mThreadSize, 1, nThreadSize>, - Sequence<0, 1, 2, 3>, - 3, - CDEShuffleBlockTransferScalarPerVector_NPerBlock, - InMemoryDataOperationEnum::Set, - 1, - true>{e_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_work_idx[I1], - m_tid * mThreadSize, - block_work_idx[I2], - n_tid * nThreadSize), - ck::tensor_operation::element_wise::PassThrough{}}; - - c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock, - make_tuple(I0, I0, I0, I0), - e_thread_zero_buf, - e_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_buf); - - __syncthreads(); - - if(threadIdx.x == 0) + if(block_work_idx[I0] == 0) { - atomicAdd(barrier_count_finished, 1); + 
const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock; + const index_t numNThreads = NPerBlock / nThreadSize; + const index_t numMThreads = BlockSize / numNThreads; + const index_t mThreadSize = MPerBlock / numMThreads; + + const index_t m_tid = get_thread_local_1d_id() / numNThreads; + const index_t n_tid = get_thread_local_1d_id() % numNThreads; + + auto c_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + StaticBuffer + e_thread_zero_buf; + + auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3< + EDataType, + EDataType, + decltype(c_thread_desc_mblock_mperblock_nblock_nperblock), + decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + ck::tensor_operation::element_wise::PassThrough, + Sequence<1, mThreadSize, 1, nThreadSize>, + Sequence<0, 1, 2, 3>, + 3, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + true>{e_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], + m_tid * mThreadSize, + block_work_idx[I2], + n_tid * nThreadSize), + ck::tensor_operation::element_wise::PassThrough{}}; + + c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + e_thread_zero_buf, + e_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_buf); + + __builtin_amdgcn_s_barrier(); + + if(threadIdx.x == 0) + { + atomicAdd(barrier_count_finished, 1); + } } } -#endif auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); @@ -675,10 +682,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + 
a_block_space_size_aligned, b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0); @@ -711,13 +718,15 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle // shuffle C and write out { - if(threadIdx.x == 0) + if constexpr(Zeroing) { - while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {} + if(threadIdx.x == 0) + { + while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {} + } + __builtin_amdgcn_s_barrier(); } - __syncthreads(); - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, "wrong!"); @@ -951,18 +960,131 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle } }); - if(threadIdx.x == 0) + if constexpr(Zeroing) { - index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1); - - if(k_id_finished_t == KBatch) + if(threadIdx.x == 0) { - *barrier_count_finished = 0; + index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1); + + if(k_id_finished_t == KBatch) + { + *barrier_count_finished = 0; + } } } } } + template + __device__ static void RunWithZeroing(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + uint32_t* barrier_count_finished, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch, + const Block2ETileMap& block_2_etile_map) + { + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, 
NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + if(kbatch_id == KBatch - 1) + { + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + else + { + Run, true>( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + ck::tensor_operation::element_wise::PassThrough{}, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + } + template ( - p_a_grid, - p_b_grid, - 
p_ds_grid, - p_e_grid, - p_shared, - barrier_count_finished, - KBatch, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_kbatch_ak0_m_ak1, - b_grid_desc_kbatch_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); - } - else - { - Run>( - p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - barrier_count_finished, - KBatch, - a_element_op, - b_element_op, - ck::tensor_operation::element_wise::PassThrough{}, - a_grid_desc_kbatch_ak0_m_ak1, - b_grid_desc_kbatch_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); - } + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + nullptr, + KBatch, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); } }; From ae57e5938e7fdfd049055a855910f66054e04163 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 2 Apr 2024 09:42:17 -0700 Subject: [PATCH 2/7] Split the instances by architecture. 
(#1223) * parse examples inside the add_example_executable function * fix the example 64 cmake file * add xdl flag to the gemm_bias_softmax_gemm_permute example * add filtering of tests based on architecture type * enable test_grouped_gemm for gfx9 only * enable test_transpose only for gfx9 * only link test_transpose if it gets built * split the gemm instances by architectures * split gemm_bilinear, grouped_conv_bwd_weight instances by targets * split instances by architecture * split grouped_conv instances by architecture * fix clang format * fix the if-else logic in group_conv headers * small fix for grouped convolution instances * fix the grouped conv bwd weight dl instances * fix client examples * only enable client examples 3 and 4 on gfx9 * set the gfx9 macro * make sure the architecture macros are set by cmake * use separate set of xdl/wmma flags for host code * simplify the main cmake file * add conv_fwd_bf8 instance declaration --- CMakeLists.txt | 16 + .../02_gemm_add_add_fastgelu/CMakeLists.txt | 34 +- .../03_gemm_layernorm/CMakeLists.txt | 10 +- client_example/04_contraction/CMakeLists.txt | 23 +- .../07_grouped_convnd_fwd/CMakeLists.txt | 10 +- .../08_fused_attention/CMakeLists.txt | 10 +- client_example/09_quantization/CMakeLists.txt | 30 +- .../15_convnd_bwd_data/CMakeLists.txt | 10 +- .../15_gemm_add_multiply/CMakeLists.txt | 7 +- .../17_grouped_gemm_fastgelu/CMakeLists.txt | 6 +- client_example/20_splitk_gemm/CMakeLists.txt | 2 +- .../21_grouped_gemm_bias/CMakeLists.txt | 6 +- client_example/22_grouped_gemm/CMakeLists.txt | 18 +- .../24_grouped_conv_activation/CMakeLists.txt | 2 + client_example/25_wrapper/CMakeLists.txt | 4 +- client_example/CMakeLists.txt | 5 +- example/01_gemm/CMakeLists.txt | 11 +- example/02_gemm_bilinear/CMakeLists.txt | 23 +- example/03_gemm_bias_relu/CMakeLists.txt | 9 +- .../04_gemm_add_add_fastgelu/CMakeLists.txt | 35 +- example/09_convnd_fwd/CMakeLists.txt | 23 +- .../CMakeLists.txt | 34 +-
example/14_gemm_quantization/CMakeLists.txt | 13 +- .../CMakeLists.txt | 67 +- example/17_convnd_bwd_data/CMakeLists.txt | 15 +- .../20_grouped_conv_bwd_weight/CMakeLists.txt | 34 +- example/21_gemm_layernorm/CMakeLists.txt | 16 +- example/26_contraction/CMakeLists.txt | 32 +- .../CMakeLists.txt | 5 +- .../CMakeLists.txt | 51 +- example/31_batched_gemm_gemm/CMakeLists.txt | 20 +- .../CMakeLists.txt | 14 +- example/35_splitK_gemm/CMakeLists.txt | 43 +- .../CMakeLists.txt | 31 +- .../40_conv2d_fwd_quantization/CMakeLists.txt | 39 +- .../41_grouped_conv_conv_fwd/CMakeLists.txt | 20 +- .../CMakeLists.txt | 9 +- ...=> gemm_bias_softmax_gemm_permute_xdl.cpp} | 0 example/52_im2col_col2im/CMakeLists.txt | 18 +- example/60_gemm_multi_ABD/CMakeLists.txt | 9 +- .../61_contraction_multi_ABD/CMakeLists.txt | 9 +- example/62_convnd_activ/CMakeLists.txt | 19 +- example/64_fpAintB_gemm/CMakeLists.txt | 8 +- example/CMakeLists.txt | 36 + include/ck/ck.hpp | 14 +- .../tensor_operation_instance/gpu/gemm.hpp | 583 ++------- .../gpu/gemm_bilinear.hpp | 8 +- .../tensor_operation_instance/gpu/gemm_dl.inc | 167 +++ .../gpu/gemm_wmma.inc | 34 + .../gpu/gemm_xdl.inc | 238 ++++ .../gpu/grouped_convolution_backward_data.hpp | 673 +++------- ...grouped_convolution_backward_data_wmma.inc | 243 ++++ .../grouped_convolution_backward_data_xdl.inc | 216 ++++ .../grouped_convolution_backward_weight.hpp | 889 ++++--------- ...grouped_convolution_backward_weight_dl.inc | 243 ++++ ...ouped_convolution_backward_weight_wmma.inc | 114 ++ ...rouped_convolution_backward_weight_xdl.inc | 228 ++++ .../gpu/grouped_convolution_forward.hpp | 1121 +++-------------- .../gpu/grouped_convolution_forward_dl.inc | 73 ++ .../gpu/grouped_convolution_forward_wmma.inc | 480 +++++++ .../gpu/grouped_convolution_forward_xdl.inc | 357 ++++++ .../gpu/CMakeLists.txt | 35 + .../gpu/batched_gemm/CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../batched_gemm_bias_permute/CMakeLists.txt | 1 + 
.../gpu/batched_gemm_gemm/CMakeLists.txt | 1 + .../gpu/batched_gemm_reduce/CMakeLists.txt | 1 + .../batched_gemm_softmax_gemm/CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../gpu/contraction_bilinear/CMakeLists.txt | 1 + .../gpu/contraction_scale/CMakeLists.txt | 1 + .../gpu/conv1d_bwd_data/CMakeLists.txt | 1 + .../gpu/conv2d_bwd_data/CMakeLists.txt | 1 + .../gpu/conv2d_fwd/CMakeLists.txt | 1 + .../gpu/conv2d_fwd_bias_relu/CMakeLists.txt | 1 + .../conv2d_fwd_bias_relu_add/CMakeLists.txt | 1 + .../gpu/conv3d_bwd_data/CMakeLists.txt | 1 + .../gpu/gemm_add/CMakeLists.txt | 1 + .../gpu/gemm_add_add_fastgelu/CMakeLists.txt | 1 + .../gpu/gemm_add_fastgelu/CMakeLists.txt | 1 + .../gpu/gemm_add_multiply/CMakeLists.txt | 1 + .../gpu/gemm_add_relu/CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../gpu/gemm_add_silu/CMakeLists.txt | 1 + .../gpu/gemm_bias_add_reduce/CMakeLists.txt | 1 + .../gpu/gemm_bilinear/CMakeLists.txt | 1 + .../gpu/gemm_fastgelu/CMakeLists.txt | 1 + .../gpu/gemm_multiply_add/CMakeLists.txt | 1 + .../gpu/gemm_reduce/CMakeLists.txt | 1 + .../gpu/gemm_splitk/CMakeLists.txt | 1 + .../gpu/gemm_streamk/CMakeLists.txt | 1 + .../grouped_conv1d_bwd_weight/CMakeLists.txt | 1 + .../gpu/grouped_conv1d_fwd/CMakeLists.txt | 1 + .../grouped_conv2d_bwd_data/CMakeLists.txt | 1 + .../grouped_conv2d_bwd_weight/CMakeLists.txt | 1 + .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 1 + .../grouped_conv3d_bwd_data/CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../grouped_conv3d_bwd_weight/CMakeLists.txt | 1 + .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../grouped_conv3d_fwd_scale/CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../gpu/grouped_gemm/CMakeLists.txt | 1 + .../gpu/grouped_gemm_bias/CMakeLists.txt | 1 + .../gpu/grouped_gemm_fastgelu/CMakeLists.txt | 1 + .../gpu/grouped_gemm_fixed_nk/CMakeLists.txt | 1 + .../gpu/quantization/CMakeLists.txt | 1 + profiler/src/CMakeLists.txt | 171 +-- 
test/CMakeLists.txt | 25 +- test/batched_gemm/CMakeLists.txt | 11 +- ...hed_gemm.cpp => test_batched_gemm_xdl.cpp} | 0 test/batched_gemm_gemm/CMakeLists.txt | 19 +- ...pp => test_batched_gemm_gemm_fp16_xdl.cpp} | 0 test/batched_gemm_reduce/CMakeLists.txt | 13 +- ...6.cpp => batched_gemm_reduce_fp16_xdl.cpp} | 0 test/batched_gemm_softmax_gemm/CMakeLists.txt | 19 +- ...st_batched_gemm_softmax_gemm_fp16_xdl.cpp} | 0 .../CMakeLists.txt | 50 +- ...mm_bias_softmax_gemm_permute_bf16_xdl.cpp} | 0 ...mm_bias_softmax_gemm_permute_fp16_xdl.cpp} | 0 ...ed_gemm_softmax_gemm_permute_bf16_xdl.cpp} | 0 ...ed_gemm_softmax_gemm_permute_fp16_xdl.cpp} | 0 test/contraction/CMakeLists.txt | 21 +- ...cpp => test_contraction_interface_xdl.cpp} | 0 ...ntraction.cpp => test_contraction_xdl.cpp} | 0 test/convnd_bwd_data/CMakeLists.txt | 11 +- ...d_bwd_data.cpp => convnd_bwd_data_xdl.cpp} | 0 test/convnd_fwd/CMakeLists.txt | 11 +- .../{convnd_fwd.cpp => convnd_fwd_xdl.cpp} | 0 test/gemm_add/CMakeLists.txt | 24 +- ...elu.cpp => test_gemm_add_fastgelu_xdl.cpp} | 2 +- ...dd_relu.cpp => test_gemm_add_relu_xdl.cpp} | 2 +- ...dd_silu.cpp => test_gemm_add_silu_xdl.cpp} | 2 +- ...est_gemm_add.hpp => test_gemm_add_xdl.hpp} | 0 test/gemm_layernorm/CMakeLists.txt | 19 +- ..._gemm_add_relu_add_layernorm_fp16_xdl.cpp} | 0 test/gemm_reduce/CMakeLists.txt | 2 +- ...duce_fp16.cpp => gemm_reduce_fp16_xdl.cpp} | 0 test/gemm_split_k/CMakeLists.txt | 9 +- ...mm_splitk.cpp => test_gemm_splitk_xdl.cpp} | 0 test/grouped_convnd_bwd_data/CMakeLists.txt | 31 +- ...test_grouped_convnd_bwd_data_xdl_wmma.cpp} | 0 test/grouped_convnd_bwd_weight/CMakeLists.txt | 26 +- ...st_grouped_convnd_bwd_weight_xdl_wmma.cpp} | 0 test/grouped_convnd_fwd/CMakeLists.txt | 16 +- ...ti_d_interface_compatibility_xdl_wmma.cpp} | 0 ...p => test_grouped_convnd_fwd_xdl_wmma.cpp} | 0 test/grouped_gemm/CMakeLists.txt | 27 +- ...pp => test_grouped_gemm_interface_xdl.cpp} | 0 ...k.cpp => test_grouped_gemm_splitk_xdl.cpp} | 0 
test/normalization_bwd_data/CMakeLists.txt | 13 +- .../CMakeLists.txt | 13 +- test/permute_scale/CMakeLists.txt | 6 +- test/transpose/CMakeLists.txt | 13 +- ...t_transpose.cpp => test_transpose_xdl.cpp} | 0 test/wrapper/CMakeLists.txt | 6 +- ...per_gemm.cpp => test_wrapper_gemm_xdl.cpp} | 0 160 files changed, 3770 insertions(+), 3392 deletions(-) rename example/47_gemm_bias_softmax_gemm_permute/{gemm_bias_softmax_gemm_permute.cpp => gemm_bias_softmax_gemm_permute_xdl.cpp} (100%) create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/gemm_dl.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/gemm_wmma.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/gemm_xdl.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_dl.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dl.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc rename test/batched_gemm/{test_batched_gemm.cpp => test_batched_gemm_xdl.cpp} (100%) rename test/batched_gemm_gemm/{test_batched_gemm_gemm_fp16.cpp => test_batched_gemm_gemm_fp16_xdl.cpp} (100%) rename test/batched_gemm_reduce/{batched_gemm_reduce_fp16.cpp => batched_gemm_reduce_fp16_xdl.cpp} (100%) rename 
test/batched_gemm_softmax_gemm/{test_batched_gemm_softmax_gemm_fp16.cpp => test_batched_gemm_softmax_gemm_fp16_xdl.cpp} (100%) rename test/batched_gemm_softmax_gemm_permute/{test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp => test_batched_gemm_bias_softmax_gemm_permute_bf16_xdl.cpp} (100%) rename test/batched_gemm_softmax_gemm_permute/{test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp => test_batched_gemm_bias_softmax_gemm_permute_fp16_xdl.cpp} (100%) rename test/batched_gemm_softmax_gemm_permute/{test_batched_gemm_softmax_gemm_permute_bf16.cpp => test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp} (100%) rename test/batched_gemm_softmax_gemm_permute/{test_batched_gemm_softmax_gemm_permute_fp16.cpp => test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp} (100%) rename test/contraction/{test_contraction_interface.cpp => test_contraction_interface_xdl.cpp} (100%) rename test/contraction/{test_contraction.cpp => test_contraction_xdl.cpp} (100%) rename test/convnd_bwd_data/{convnd_bwd_data.cpp => convnd_bwd_data_xdl.cpp} (100%) rename test/convnd_fwd/{convnd_fwd.cpp => convnd_fwd_xdl.cpp} (100%) rename test/gemm_add/{test_gemm_add_fastgelu.cpp => test_gemm_add_fastgelu_xdl.cpp} (98%) rename test/gemm_add/{test_gemm_add_relu.cpp => test_gemm_add_relu_xdl.cpp} (98%) rename test/gemm_add/{test_gemm_add_silu.cpp => test_gemm_add_silu_xdl.cpp} (98%) rename test/gemm_add/{test_gemm_add.hpp => test_gemm_add_xdl.hpp} (100%) rename test/gemm_layernorm/{test_gemm_add_relu_add_layernorm_fp16.cpp => test_gemm_add_relu_add_layernorm_fp16_xdl.cpp} (100%) rename test/gemm_reduce/{gemm_reduce_fp16.cpp => gemm_reduce_fp16_xdl.cpp} (100%) rename test/gemm_split_k/{test_gemm_splitk.cpp => test_gemm_splitk_xdl.cpp} (100%) rename test/grouped_convnd_bwd_data/{test_grouped_convnd_bwd_data.cpp => test_grouped_convnd_bwd_data_xdl_wmma.cpp} (100%) rename test/grouped_convnd_bwd_weight/{test_grouped_convnd_bwd_weight.cpp => test_grouped_convnd_bwd_weight_xdl_wmma.cpp} (100%) rename 
test/grouped_convnd_fwd/{test_grouped_convnd_fwd_multi_d_interface_compatibility.cpp => test_grouped_convnd_fwd_multi_d_interface_compatibility_xdl_wmma.cpp} (100%) rename test/grouped_convnd_fwd/{test_grouped_convnd_fwd.cpp => test_grouped_convnd_fwd_xdl_wmma.cpp} (100%) rename test/grouped_gemm/{test_grouped_gemm_interface.cpp => test_grouped_gemm_interface_xdl.cpp} (100%) rename test/grouped_gemm/{test_grouped_gemm_splitk.cpp => test_grouped_gemm_splitk_xdl.cpp} (100%) rename test/transpose/{test_transpose.cpp => test_transpose_xdl.cpp} (100%) rename test/wrapper/{test_wrapper_gemm.cpp => test_wrapper_gemm_xdl.cpp} (100%) diff --git a/CMakeLists.txt b/CMakeLists.txt index bdeba33eac..3c77f520ad 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -143,6 +143,22 @@ if(GPU_TARGETS) else() message("Building CK for the following targets: ${AMDGPU_TARGETS}") endif() + +if (GPU_TARGETS) + if (GPU_TARGETS MATCHES "gfx9") + add_definitions(-DCK_USE_XDL) + set(CK_USE_XDL "ON") + endif() + if (GPU_TARGETS MATCHES "gfx11") + add_definitions(-DCK_USE_WMMA) + set(CK_USE_WMMA "ON") + endif() +else() + add_definitions(-DCK_USE_WMMA -DCK_USE_XDL) + set(CK_USE_XDL "ON") + set(CK_USE_WMMA "ON") +endif() + find_package(hip) # No assumption that HIP kernels are launched with uniform block size for backward compatibility # SWDEV-413293 and https://reviews.llvm.org/D155213 diff --git a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt index 772b699955..4ba86026b2 100644 --- a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt +++ b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt @@ -1,27 +1,29 @@ -add_custom_target(client_gemm_fastgelu_examples) +if(GPU_TARGETS MATCHES "gfx9") + add_custom_target(client_gemm_fastgelu_examples) -add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp) -target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_gemm_operations) + 
add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp) + target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_gemm_add_fastgelu gemm_add_fastgelu.cpp) -target_link_libraries(client_gemm_add_fastgelu PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_gemm_add_fastgelu gemm_add_fastgelu.cpp) + target_link_libraries(client_gemm_add_fastgelu PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_gemm_fastgelu gemm_fastgelu.cpp) -target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_gemm_fastgelu gemm_fastgelu.cpp) + target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations) -add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu + add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu client_gemm_fastgelu) -add_custom_target(client_gemm_fastgelu_generic_examples) + add_custom_target(client_gemm_fastgelu_generic_examples) -add_executable(client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp) -target_link_libraries(client_gemm_add_add_fastgelu_generic composable_kernel::device_gemm_operations) + add_executable(client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp) + target_link_libraries(client_gemm_add_add_fastgelu_generic composable_kernel::device_gemm_operations) -add_executable(client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp) -target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp) + target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp) 
-target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp) + target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations) -add_dependencies(client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic + add_dependencies(client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic client_gemm_add_fastgelu_generic client_gemm_fastgelu_generic) +endif() diff --git a/client_example/03_gemm_layernorm/CMakeLists.txt b/client_example/03_gemm_layernorm/CMakeLists.txt index 94b4576f64..8fedc84635 100644 --- a/client_example/03_gemm_layernorm/CMakeLists.txt +++ b/client_example/03_gemm_layernorm/CMakeLists.txt @@ -1,5 +1,7 @@ -add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp) -target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp) + target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations) -add_executable(client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp) -target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations) + add_executable(client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp) + target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations) +endif() diff --git a/client_example/04_contraction/CMakeLists.txt b/client_example/04_contraction/CMakeLists.txt index cd4a95124c..13c0375846 100644 --- 
a/client_example/04_contraction/CMakeLists.txt +++ b/client_example/04_contraction/CMakeLists.txt @@ -1,15 +1,16 @@ -add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp) -target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp) + target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) -add_executable(client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp) -target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) + add_executable(client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp) + target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) -add_executable(client_contraction_scale_fp64 contraction_scale_fp64.cpp) -target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) + add_executable(client_contraction_scale_fp64 contraction_scale_fp64.cpp) + target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) -add_executable(client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp) -target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_other_operations 
composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) - -add_executable(contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp) -target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) + add_executable(client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp) + target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) + add_executable(contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp) + target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/07_grouped_convnd_fwd/CMakeLists.txt b/client_example/07_grouped_convnd_fwd/CMakeLists.txt index 40f1bba064..710eca9f49 100644 --- a/client_example/07_grouped_convnd_fwd/CMakeLists.txt +++ b/client_example/07_grouped_convnd_fwd/CMakeLists.txt @@ -1,5 +1,7 @@ -add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp) -target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_conv_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp) + target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_conv_operations) -add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp) -target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations) + add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp) + target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations) +endif() \ No newline at end of file diff --git 
a/client_example/08_fused_attention/CMakeLists.txt b/client_example/08_fused_attention/CMakeLists.txt index 9472be07b5..4bcde367dc 100644 --- a/client_example/08_fused_attention/CMakeLists.txt +++ b/client_example/08_fused_attention/CMakeLists.txt @@ -1,5 +1,7 @@ -add_executable(client_fused_attention fused_attention.cpp) -target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_fused_attention fused_attention.cpp) + target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_fused_attention_bias fused_attention_bias.cpp) -target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_fused_attention_bias fused_attention_bias.cpp) + target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/09_quantization/CMakeLists.txt b/client_example/09_quantization/CMakeLists.txt index 65ad642ce2..d2d3a427e8 100644 --- a/client_example/09_quantization/CMakeLists.txt +++ b/client_example/09_quantization/CMakeLists.txt @@ -1,22 +1,22 @@ -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) -add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) +if(GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)) + add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp) + 
target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp) + target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp) + target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp) + 
target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp) -target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp) + target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp) -target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp) + target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_gemm_quantization gemm_quantization.cpp) -target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_gemm_quantization gemm_quantization.cpp) + target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) endif() diff --git a/client_example/15_convnd_bwd_data/CMakeLists.txt 
b/client_example/15_convnd_bwd_data/CMakeLists.txt index f35cd82d79..8fc62bc2bb 100644 --- a/client_example/15_convnd_bwd_data/CMakeLists.txt +++ b/client_example/15_convnd_bwd_data/CMakeLists.txt @@ -1,5 +1,7 @@ -add_executable(client_conv3d_bwd_data_fp16 conv3d_bwd_data_fp16.cpp) -add_executable(client_conv3d_bwd_data_fp32 conv3d_bwd_data_fp32.cpp) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_conv3d_bwd_data_fp16 conv3d_bwd_data_fp16.cpp) + add_executable(client_conv3d_bwd_data_fp32 conv3d_bwd_data_fp32.cpp) -target_link_libraries(client_conv3d_bwd_data_fp16 PRIVATE composable_kernel::device_conv_operations) -target_link_libraries(client_conv3d_bwd_data_fp32 PRIVATE composable_kernel::device_conv_operations) + target_link_libraries(client_conv3d_bwd_data_fp16 PRIVATE composable_kernel::device_conv_operations) + target_link_libraries(client_conv3d_bwd_data_fp32 PRIVATE composable_kernel::device_conv_operations) +endif() diff --git a/client_example/15_gemm_add_multiply/CMakeLists.txt b/client_example/15_gemm_add_multiply/CMakeLists.txt index 4b4d762003..a683f78571 100644 --- a/client_example/15_gemm_add_multiply/CMakeLists.txt +++ b/client_example/15_gemm_add_multiply/CMakeLists.txt @@ -1,3 +1,4 @@ - -add_executable(client_gemm_add_multiply gemm_add_multiply.cpp) -target_link_libraries(client_gemm_add_multiply PRIVATE composable_kernel::device_gemm_operations) \ No newline at end of file +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_gemm_add_multiply gemm_add_multiply.cpp) + target_link_libraries(client_gemm_add_multiply PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt b/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt index fd315afbd2..39bef71814 100644 --- a/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt +++ b/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt @@ -1,2 +1,4 @@ -add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp) 
-target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations) \ No newline at end of file +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp) + target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/20_splitk_gemm/CMakeLists.txt b/client_example/20_splitk_gemm/CMakeLists.txt index a3dc853767..05fcaa8103 100644 --- a/client_example/20_splitk_gemm/CMakeLists.txt +++ b/client_example/20_splitk_gemm/CMakeLists.txt @@ -1,4 +1,4 @@ -if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) +if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)) add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp) target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_gemm_operations) endif() diff --git a/client_example/21_grouped_gemm_bias/CMakeLists.txt b/client_example/21_grouped_gemm_bias/CMakeLists.txt index 92e31495c2..a09921e50a 100644 --- a/client_example/21_grouped_gemm_bias/CMakeLists.txt +++ b/client_example/21_grouped_gemm_bias/CMakeLists.txt @@ -1,2 +1,4 @@ -add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp) -target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_gemm_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/22_grouped_gemm/CMakeLists.txt b/client_example/22_grouped_gemm/CMakeLists.txt index 0c3cb956f0..1e1c39681e 100644 --- a/client_example/22_grouped_gemm/CMakeLists.txt +++ b/client_example/22_grouped_gemm/CMakeLists.txt @@ -1,11 +1,13 @@ 
-add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp) -target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_gemm_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_grouped_gemm_fixed_nk_fp8 grouped_gemm_fixed_nk_fp8.cpp) -target_link_libraries(client_grouped_gemm_fixed_nk_fp8 PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_grouped_gemm_fixed_nk_fp8 grouped_gemm_fixed_nk_fp8.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_fp8 PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_grouped_gemm_fixed_nk_i8 grouped_gemm_fixed_nk_i8.cpp) -target_link_libraries(client_grouped_gemm_fixed_nk_i8 PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_grouped_gemm_fixed_nk_i8 grouped_gemm_fixed_nk_i8.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_i8 PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_grouped_gemm_fixed_nk_bf16 grouped_gemm_fixed_nk_bf16.cpp) -target_link_libraries(client_grouped_gemm_fixed_nk_bf16 PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_grouped_gemm_fixed_nk_bf16 grouped_gemm_fixed_nk_bf16.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_bf16 PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/24_grouped_conv_activation/CMakeLists.txt b/client_example/24_grouped_conv_activation/CMakeLists.txt index 074dcd9b97..e79dee9f7d 100644 --- a/client_example/24_grouped_conv_activation/CMakeLists.txt +++ b/client_example/24_grouped_conv_activation/CMakeLists.txt @@ -1,3 +1,4 @@ +if(GPU_TARGETS MATCHES "gfx9") # Fwd scaleadd scaleadd relu add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 
grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp) @@ -46,3 +47,4 @@ target_link_libraries(client_grouped_convnd_fwd_scale_fp16 PRIVATE composable_ke add_executable(client_grouped_convnd_bwd_data_scale_fp16 grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp) target_link_libraries(client_grouped_convnd_bwd_data_scale_fp16 PRIVATE composable_kernel::device_conv_operations) +endif() diff --git a/client_example/25_wrapper/CMakeLists.txt b/client_example/25_wrapper/CMakeLists.txt index fdfc1d8d2e..b1e9d20bfd 100644 --- a/client_example/25_wrapper/CMakeLists.txt +++ b/client_example/25_wrapper/CMakeLists.txt @@ -2,9 +2,7 @@ add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrap target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations) add_executable(client_wrapper_img2col wrapper_img2col.cpp) target_link_libraries(client_wrapper_img2col PRIVATE composable_kernel::device_other_operations) -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR - GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR - GPU_TARGETS MATCHES "gfx942") +if(GPU_TARGETS MATCHES "gfx9") add_executable(client_wrapper_basic_gemm wrapper_basic_gemm.cpp) target_link_libraries(client_wrapper_basic_gemm PRIVATE composable_kernel::device_other_operations) add_executable(client_wrapper_optimized_gemm wrapper_optimized_gemm.cpp) diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index 753f5e5ae5..3aa9efa315 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -48,7 +48,10 @@ else() endif() endif() -find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_contraction_operations device_reduction_operations) +find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations 
device_reduction_operations) +if(GPU_TARGETS MATCHES "gfx9") + find_package(composable_kernel COMPONENTS device_contraction_operations) +endif() find_package(hip REQUIRED PATHS /opt/rocm) message(STATUS "Build with HIP ${hip_VERSION}") diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 2fa8e77462..39e3f2a2bd 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -27,11 +27,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) -if(GPU_TARGETS MATCHES "gfx11") - add_custom_target(example_gemm_wmma) - add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) - add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16) -endif() add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16) @@ -47,8 +42,7 @@ if(USE_BITINT_EXTENSION_INT4) add_example_dependencies(example_gemm_xdl example_gemm_xdl_int4) endif(USE_BITINT_EXTENSION_INT4) -# FIXME: re-enable this example as test when SWDEV-335738 is fixed -add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) +add_example_executable(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64) add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) @@ -75,3 +69,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8) add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8) +add_custom_target(example_gemm_wmma) +add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16) diff --git a/example/02_gemm_bilinear/CMakeLists.txt 
b/example/02_gemm_bilinear/CMakeLists.txt index d82c42d5a9..2c20b96eee 100644 --- a/example/02_gemm_bilinear/CMakeLists.txt +++ b/example/02_gemm_bilinear/CMakeLists.txt @@ -1,20 +1,3 @@ -list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102) -list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp) - add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp) -endif() -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") - set(target 1) - endif() -endforeach() - -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp) +add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp) +add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp) diff --git a/example/03_gemm_bias_relu/CMakeLists.txt b/example/03_gemm_bias_relu/CMakeLists.txt index 2f5cba924d..35c54abac0 100644 --- a/example/03_gemm_bias_relu/CMakeLists.txt +++ b/example/03_gemm_bias_relu/CMakeLists.txt @@ -1,8 +1 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp) diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index 33ac1e7e77..ab19f819e8 100644 --- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ 
b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -1,29 +1,20 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_gemm_add_add_fastgelu_xdl) - add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) - add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) +add_custom_target(example_gemm_add_add_fastgelu_xdl) +add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) +add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) - add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) - add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) +add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) +add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) - add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) - add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) +add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) +add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp) - add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) - endif(USE_BITINT_EXTENSION_INT4) +add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) +add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) - 
add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) - add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) - set(target 1) - endif() -endforeach() - -set(gpu_list "") +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) +endif(USE_BITINT_EXTENSION_INT4) list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942) set(target 0) diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 195f1857ed..61e9a43c3a 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -1,19 +1,10 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) - add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) - add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) - add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) - add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp) - add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp) - # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed - add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) - set(target 1) - endif() -endforeach() - +add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) +add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) +add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) +add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) +add_example_executable(example_convnd_fwd_xdl_fp8 
convnd_fwd_xdl_fp8.cpp) +add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) +add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp) add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp) diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt index 222a3b7c0b..ef8bea1850 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt @@ -1,25 +1,17 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_convnd_fwd_reduce_xdl) +add_custom_target(example_convnd_fwd_reduce_xdl) +add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) +add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) - add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) - add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) +add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) +add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) - add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) - add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) +add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) +add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) - add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) - 
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) +add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) +add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) - add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) - add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) - - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) - add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) - endif(USE_BITINT_EXTENSION_INT4) - set(target 1) - endif() -endforeach() +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) + add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) +endif(USE_BITINT_EXTENSION_INT4) diff --git a/example/14_gemm_quantization/CMakeLists.txt b/example/14_gemm_quantization/CMakeLists.txt index 9793e8b8a0..8703fa3ed7 100644 --- a/example/14_gemm_quantization/CMakeLists.txt +++ b/example/14_gemm_quantization/CMakeLists.txt @@ -1,12 +1,3 @@ -# dlops add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp) -# xdlops -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp) - add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp) +add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp) diff --git a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt 
b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt index 5955e1d6cb..1e12c16f30 100644 --- a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt +++ b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt @@ -1,48 +1,41 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_gemm_reduce_xdl) - add_custom_target(example_gemm_reduce_xdl_max) - add_custom_target(example_gemm_reduce_xdl_mean_meansquare) - add_custom_target(example_gemm_add_add_mean_meansquare_xdl) +add_custom_target(example_gemm_reduce_xdl) +add_custom_target(example_gemm_reduce_xdl_max) +add_custom_target(example_gemm_reduce_xdl_mean_meansquare) +add_custom_target(example_gemm_add_add_mean_meansquare_xdl) - add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) - add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16) +add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) +add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16) - add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) - add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) +add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) +add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) - add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) - add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16) +add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) +add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16) - 
add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) - add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8) +add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) +add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8) - add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) - add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8) +add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) +add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8) - add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) - add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32) +add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) +add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32) - add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) - add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32) +add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) +add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32) - add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) - add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16) +add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) +add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16) - add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) - add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16) +add_example_executable(example_gemm_mean_meansquare_xdl_bf16 
gemm_mean_meansquare_xdl_bf16.cpp) +add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16) - add_example_dependencies(example_gemm_reduce_xdl - example_gemm_reduce_xdl_mean_meansquare - example_gemm_reduce_xdl_max - example_gemm_add_add_mean_meansquare_xdl) +add_example_dependencies(example_gemm_reduce_xdl + example_gemm_reduce_xdl_mean_meansquare + example_gemm_reduce_xdl_max + example_gemm_add_add_mean_meansquare_xdl) - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp) - add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4) - endif() - set(target 1) - endif() -endforeach() +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp) + add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4) +endif() diff --git a/example/17_convnd_bwd_data/CMakeLists.txt b/example/17_convnd_bwd_data/CMakeLists.txt index 7c6d10d8a0..39f9fb8ec0 100644 --- a/example/17_convnd_bwd_data/CMakeLists.txt +++ b/example/17_convnd_bwd_data/CMakeLists.txt @@ -1,14 +1,7 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) - endif() - set(target 1) - endif() -endforeach() +add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp) +if(result EQUAL 0) + target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) +endif() add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp) if(result EQUAL 0) diff --git a/example/20_grouped_conv_bwd_weight/CMakeLists.txt b/example/20_grouped_conv_bwd_weight/CMakeLists.txt index c28fca6fa2..497ea19e11 100644 --- 
a/example/20_grouped_conv_bwd_weight/CMakeLists.txt +++ b/example/20_grouped_conv_bwd_weight/CMakeLists.txt @@ -1,29 +1,15 @@ -list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_weight) - add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) - add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16) +add_custom_target(example_grouped_conv_bwd_weight) +add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16) - add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) - add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) +add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) - add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp) - add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8) - set(target 1) - endif() +add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8) - if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_weight) - add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp) - 
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16) - set(target 1) - endif() -endforeach() - -add_custom_target(example_grouped_conv_bwd_weight_dl) +add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16) add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp) -add_example_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_dl_fp16) diff --git a/example/21_gemm_layernorm/CMakeLists.txt b/example/21_gemm_layernorm/CMakeLists.txt index e231bc619b..2eb7052e1e 100644 --- a/example/21_gemm_layernorm/CMakeLists.txt +++ b/example/21_gemm_layernorm/CMakeLists.txt @@ -1,12 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp) - add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp) - add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp) - add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp) - set(target 1) - endif() -endforeach() - +add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp) +add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp) +add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp) 
+add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp) diff --git a/example/26_contraction/CMakeLists.txt b/example/26_contraction/CMakeLists.txt index 1a0489ce9c..f3d30cea2a 100644 --- a/example/26_contraction/CMakeLists.txt +++ b/example/26_contraction/CMakeLists.txt @@ -4,49 +4,49 @@ add_custom_target(example_contraction_bilinear) # FP32 add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32) add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32) add_example_executable(example_contraction_bilinear_xdl_fp32_compute_bf16 contraction_bilinear_xdl_fp32_compute_bf16.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_bf16) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_bf16) add_example_executable(example_contraction_scale_xdl_fp32_compute_bf16 contraction_scale_xdl_fp32_compute_bf16.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_bf16) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_bf16) add_example_executable(example_contraction_bilinear_xdl_fp32_compute_fp16 contraction_bilinear_xdl_fp32_compute_fp16.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_fp16) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_fp16) add_example_executable(example_contraction_scale_xdl_fp32_compute_fp16 
contraction_scale_xdl_fp32_compute_fp16.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_fp16) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_fp16) # FP64 add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64) add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64) add_example_executable(example_contraction_bilinear_xdl_fp64_compute_fp32 contraction_bilinear_xdl_fp64_compute_fp32.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64_compute_fp32) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64_compute_fp32) add_example_executable(example_contraction_scale_xdl_fp64_compute_fp32 contraction_scale_xdl_fp64_compute_fp32.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64_compute_fp32) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64_compute_fp32) # FP16 add_example_executable(example_contraction_bilinear_xdl_fp16_compute_fp32 contraction_bilinear_xdl_fp16_compute_fp32.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16_compute_fp32) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16_compute_fp32) add_example_executable(example_contraction_scale_xdl_fp16_compute_fp32 contraction_scale_xdl_fp16_compute_fp32.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16_compute_fp32) 
+add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16_compute_fp32) # BF16 add_example_executable(example_contraction_bilinear_xdl_bf16_compute_fp32 contraction_bilinear_xdl_bf16_compute_fp32.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16_compute_fp32) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16_compute_fp32) add_example_executable(example_contraction_scale_xdl_bf16_compute_fp32 contraction_scale_xdl_bf16_compute_fp32.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16_compute_fp32) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16_compute_fp32) -add_dependencies(example_contraction example_contraction_scale) -add_dependencies(example_contraction example_contraction_bilinear) +add_example_dependencies(example_contraction example_contraction_scale) +add_example_dependencies(example_contraction example_contraction_bilinear) diff --git a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt index f343cc1910..ac54aebdc2 100644 --- a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt +++ b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt @@ -1,5 +1,2 @@ add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) - -if(GPU_TARGETS MATCHES "gfx11") - add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) -endif() +add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) diff --git a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt index 3a8c2ef52f..7acb1a1907 100644 --- a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt +++ b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt @@ -1,40 +1,23 
@@ -list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102) +add_custom_target(example_grouped_conv_fwd_multiple_d) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - add_custom_target(example_grouped_conv_fwd_multiple_d) +add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) - add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) +add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) - add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) - add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) +add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) - add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) 
+add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) - add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp) + add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) +endif() # USE_BITINT_EXTENSION_INT4 - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) - add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) - - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp) - add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) - endif() # USE_BITINT_EXTENSION_INT4 - - set(target 1) - endif() -endforeach() - -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp) diff --git a/example/31_batched_gemm_gemm/CMakeLists.txt b/example/31_batched_gemm_gemm/CMakeLists.txt index 93f16c945f..8b648a7f73 100644 --- 
a/example/31_batched_gemm_gemm/CMakeLists.txt +++ b/example/31_batched_gemm_gemm/CMakeLists.txt @@ -1,17 +1,9 @@ -list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) - -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) - add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) - add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) - endif(USE_BITINT_EXTENSION_INT4) - set(target 1) - endif() -endforeach() +add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) +add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) +add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) +endif(USE_BITINT_EXTENSION_INT4) if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp) diff --git a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt index c6cca7b586..519f539106 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt +++ b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt @@ -1,11 +1,9 @@ -if(GPU_TARGETS MATCHES "gfx11") - add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp) - add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp) - 
add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp) - add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp) - add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp) - add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp) -endif() +add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp) +add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp) +add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp) +add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp) +add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp) +add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp) add_custom_target(example_gemm_scale_softmax_gemm) diff --git a/example/35_splitK_gemm/CMakeLists.txt b/example/35_splitK_gemm/CMakeLists.txt index 5277b32f63..9a62d85ace 100644 --- a/example/35_splitK_gemm/CMakeLists.txt +++ b/example/35_splitK_gemm/CMakeLists.txt @@ -1,32 +1,23 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_splitK_gemm_xdl) +add_custom_target(example_splitK_gemm_xdl) +add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32) - add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) - 
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32) +add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) - add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) +add_example_executable(example_splitK_gemm_xdl_fp16_fp8 splitK_gemm_xdl_fp16_fp8.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16_fp8) - add_example_executable(example_splitK_gemm_xdl_fp16_fp8 splitK_gemm_xdl_fp16_fp8.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16_fp8) +add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16) - add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16) +add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16) - add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16) +add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8) - add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8) - - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) - 
endif() - - set(target 1) - endif() -endforeach() +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) +endif() diff --git a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt index 1ae179e950..72e6959649 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt +++ b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt @@ -1,27 +1,10 @@ -list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_data) +add_custom_target(example_grouped_conv_bwd_data) - add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp) - add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16) +add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16) - add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp) - add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16) +add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16) - set(target 1) - endif() -endforeach() - -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_data) - - add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 
grouped_conv_bwd_data_wmma_fp16.cpp) - add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16) - - set(target 1) - endif() -endforeach() +add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16) diff --git a/example/40_conv2d_fwd_quantization/CMakeLists.txt b/example/40_conv2d_fwd_quantization/CMakeLists.txt index 2d804cafe9..991c1e464b 100644 --- a/example/40_conv2d_fwd_quantization/CMakeLists.txt +++ b/example/40_conv2d_fwd_quantization/CMakeLists.txt @@ -1,24 +1,17 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp) - add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp) - add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp) - add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp) +add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp) +add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp) +add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp) - # Conv perlayer quantization - 
add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp) - # Conv perchannel quantization - add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp) - # Conv + bias + relu perlayer quantization - add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp) - # Conv + bias + relu perchannel quantization - add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp) - # Conv + bias + tanh perlayer quantization - add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp) - # Conv + bias + tanh perchannel quantization - add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp) +# Conv perlayer quantization +add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp) +# Conv perchannel quantization +add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp) +# Conv + bias + relu perlayer quantization +add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp) +# Conv + bias + relu perchannel quantization +add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp) +# Conv + bias + tanh perlayer quantization +add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp) +# Conv + bias + tanh perchannel quantization 
+add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp) diff --git a/example/41_grouped_conv_conv_fwd/CMakeLists.txt b/example/41_grouped_conv_conv_fwd/CMakeLists.txt index ae251e88d2..8ab56b21a6 100644 --- a/example/41_grouped_conv_conv_fwd/CMakeLists.txt +++ b/example/41_grouped_conv_conv_fwd/CMakeLists.txt @@ -1,17 +1,9 @@ -list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list2 gfx908 gfx90a) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp) - add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) - add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp) - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp) - endif(USE_BITINT_EXTENSION_INT4) - set(target 1) - endif() -endforeach() +add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp) +add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) +add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp) +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp) +endif(USE_BITINT_EXTENSION_INT4) if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) diff --git a/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt b/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt index 14432f6e23..df1956ca62 100644 --- a/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt +++ 
b/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt @@ -1,8 +1 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute_xdl.cpp) diff --git a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp similarity index 100% rename from example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp rename to example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp diff --git a/example/52_im2col_col2im/CMakeLists.txt b/example/52_im2col_col2im/CMakeLists.txt index 4dc6c8b4e0..63ee1d4312 100644 --- a/example/52_im2col_col2im/CMakeLists.txt +++ b/example/52_im2col_col2im/CMakeLists.txt @@ -1,15 +1,7 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_im2col_col2im) +add_custom_target(example_im2col_col2im) - add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp) - add_example_dependencies(example_im2col_col2im example_image_to_column_f32) +add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp) +add_example_dependencies(example_im2col_col2im example_image_to_column_f32) - add_example_executable(example_column_to_image_f32 column_to_image_f32.cpp) - add_example_dependencies(example_im2col_col2im example_column_to_image_f32) - - set(target 1) - endif() -endforeach() +add_example_executable(example_column_to_image_f32 column_to_image_f32.cpp) +add_example_dependencies(example_im2col_col2im example_column_to_image_f32) diff --git 
a/example/60_gemm_multi_ABD/CMakeLists.txt b/example/60_gemm_multi_ABD/CMakeLists.txt index 57bc0b33ef..d3974897fe 100644 --- a/example/60_gemm_multi_ABD/CMakeLists.txt +++ b/example/60_gemm_multi_ABD/CMakeLists.txt @@ -1,8 +1 @@ -list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp) diff --git a/example/61_contraction_multi_ABD/CMakeLists.txt b/example/61_contraction_multi_ABD/CMakeLists.txt index 42500b64e6..1b8bd4cad2 100644 --- a/example/61_contraction_multi_ABD/CMakeLists.txt +++ b/example/61_contraction_multi_ABD/CMakeLists.txt @@ -1,8 +1 @@ -list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp) diff --git a/example/62_convnd_activ/CMakeLists.txt b/example/62_convnd_activ/CMakeLists.txt index 6eaddd3ff7..5a35f9b608 100644 --- a/example/62_convnd_activ/CMakeLists.txt +++ b/example/62_convnd_activ/CMakeLists.txt @@ -2,16 +2,9 @@ add_subdirectory(binary) add_subdirectory(multi_AB) add_subdirectory(unary) -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_convnd_activ_xdl) - # ScaleAdd ScaleAdd Relu - add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp) - add_example_dependencies(example_convnd_activ_xdl 
example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16) - add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp) - add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16) - set(target 1) - endif() -endforeach() +add_custom_target(example_convnd_activ_xdl) +# ScaleAdd ScaleAdd Relu +add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp) +add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16) +add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp) +add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16) diff --git a/example/64_fpAintB_gemm/CMakeLists.txt b/example/64_fpAintB_gemm/CMakeLists.txt index 89cc2d7f62..b3c77b3bd7 100644 --- a/example/64_fpAintB_gemm/CMakeLists.txt +++ b/example/64_fpAintB_gemm/CMakeLists.txt @@ -1,5 +1,3 @@ -if(GPU_TARGETS MATCHES "gfx11") - add_custom_target(example_fpAintB_gemm_wmma) - add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp) - add_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma) -endif() +add_custom_target(example_fpAintB_gemm_wmma) +add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp) +add_example_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index c19ba93b69..5465adb779 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -5,6 +5,12 @@ include_directories(BEFORE add_custom_target(examples) +function(add_example_dependencies EXAMPLE_NAME FILE_NAME) + if(FILE_NAME) + add_dependencies(EXAMPLE_NAME FILE_NAME) + endif() +endfunction(add_example_dependencies 
EXAMPLE_NAME) + function(add_example_executable EXAMPLE_NAME FILE_NAME) message("adding example ${EXAMPLE_NAME}") set(result 1) @@ -38,12 +44,27 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) endif() endforeach() endif() + #Do not build any DL examples if DL_KERNELS not set foreach(source IN LISTS FILE_NAME) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") message("removing dl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() + #Do not build any XDL examples if gfx9 targets are not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + message("removing xdl example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #Do not build any WMMA examples if gfx11 targets are not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + message("removing wmma example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() #only continue if there are some source files left on the list if(FILE_NAME) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) @@ -97,12 +118,27 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) endif() endforeach() endif() + #Do not build any DL examples if DL_KERNELS not set foreach(source IN LISTS FILE_NAME) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") message("removing dl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() + #Do not build any XDL examples if gfx9 targets are not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + message("removing xdl example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #Do not build any WMMA examples if gfx11 targets are not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + 
message("removing wmma example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() #only continue if there are some source files left on the list if(FILE_NAME) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index c93d1d0639..0bda8b7590 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -45,6 +45,10 @@ #endif // define general macros for various architectures +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) +#define __gfx9__ +#endif #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #define __gfx94__ #endif @@ -62,8 +66,7 @@ // buffer resource #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_BUFFER_RESOURCE_3RD_DWORD -1 -#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx94__) +#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx9__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #elif defined(__gfx103__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 @@ -75,8 +78,7 @@ #ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing #elif defined(__gfx803__) || defined(__gfx900__) // for GPU code #define CK_USE_AMD_V_MAC_F32 -#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx103__) || \ - defined(__gfx94__) // for GPU code +#elif defined(__gfx906__) || defined(__gfx9__) || defined(__gfx103__) // for GPU code #define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT4_I32_I8 @@ -89,7 +91,7 @@ // MFMA instruction #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_MFMA -#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code +#elif defined(__gfx9__) // for GPU code #define CK_USE_AMD_MFMA #endif @@ -120,7 +122,7 @@ // buffer atomic add: floating point 
#ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 -#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code +#elif defined(__gfx9__) // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 #else // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp index ee9d977096..50c18fc22e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp @@ -12,398 +12,21 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#ifdef DL_KERNELS +#include "gemm_dl.inc" +#endif +#ifdef CK_USE_WMMA +#include "gemm_wmma.inc" +#endif +#ifdef CK_USE_XDL +#include "gemm_xdl.inc" +#endif + namespace ck { namespace tensor_operation { namespace device { namespace instance { -#if defined(CK_ENABLE_FP16) && defined(DL_KERNELS) -void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances( - 
std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances( - std::vector>>& - instances); -#endif -#if defined(CK_ENABLE_FP32) && defined(DL_KERNELS) -void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances( - std::vector>>& - instances); -#endif -#if defined(CK_ENABLE_INT8) && defined(DL_KERNELS) -void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( - std::vector>>& - instances); - -void 
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances( - std::vector>>& - instances); - -void 
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_FP64 -void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_FP8 -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_default_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances( - std::vector>>& instances); - -void 
add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_padded_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_padded_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_padded_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances( - std::vector>>& - instances); -#endif - -void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); - template > op_ptrs; +#ifdef DL_KERNELS if constexpr(is_same_v && is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(op_ptrs); -#ifdef DL_KERNELS add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs); + } + } +#ifdef CK_ENABLE_FP16 + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs); + 
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs); + add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs); + add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); + } + } #endif +#ifdef CK_ENABLE_INT8 + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs); + 
add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(op_ptrs); + } + } +#endif +#endif // DL_KERNELS + +#ifdef CK_USE_WMMA +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(op_ptrs); + } + } +#endif +#endif + +#ifdef CK_USE_XDL + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances( op_ptrs); @@ -452,10 +196,6 @@ struct DeviceOperationInstanceFactory< else if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs); -#endif add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances( op_ptrs); @@ -463,10 +203,6 @@ struct DeviceOperationInstanceFactory< else if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs); -#endif add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances( op_ptrs); @@ -474,10 +210,6 @@ struct DeviceOperationInstanceFactory< else if constexpr(is_same_v && 
is_same_v && is_same_v) { - /// add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs); -#endif add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances( op_ptrs); @@ -490,57 +222,25 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); - add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs); - add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs); - add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs); -#endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs); - add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); - add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs); - add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(op_ptrs); - add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs); -#endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances( op_ptrs); - add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs); - add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); - 
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs); - add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); -#endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs); - add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs); - add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); - add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs); - add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); -#endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs); - add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(op_ptrs); } } #endif @@ -578,37 +278,21 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs); - add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(op_ptrs); -#endif } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs); - add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(op_ptrs); -#endif } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs); - add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(op_ptrs); -#endif } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs); - add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(op_ptrs); -#endif } } #endif 
@@ -658,6 +342,7 @@ struct DeviceOperationInstanceFactory< add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances(op_ptrs); } } +#endif #endif return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp index 1a518a5302..6ee88bd855 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp @@ -16,7 +16,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -#ifdef CK_ENABLE_FP16 +#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL) void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances( std::vector>>& instances); #endif -#ifdef CK_ENABLE_INT8 +#if defined(CK_ENABLE_INT8) && defined(CK_USE_WMMA) void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances( std::vector> op_ptrs; -#ifdef CK_ENABLE_FP16 +#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL) if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { @@ -189,7 +189,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v && is_same_v) { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_dl.inc b/library/include/ck/library/tensor_operation_instance/gpu/gemm_dl.inc new file mode 100644 index 0000000000..44a11f6284 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_dl.inc @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#if defined(CK_ENABLE_FP16) +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances( + std::vector>>& + instances); +#endif +#if 
defined(CK_ENABLE_FP32) +void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#if defined(CK_ENABLE_INT8) +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances( + std::vector>>& + instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/gemm_wmma.inc new file mode 100644 index 0000000000..c97298c258 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_wmma.inc @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/gemm_xdl.inc new file mode 100644 index 0000000000..82a1dc425a --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_xdl.inc @@ -0,0 +1,238 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_INT8 +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void 
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances); + +void 
add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_FP64 +void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_FP8 +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_padded_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_padded_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_padded_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances( + std::vector>>& + instances); +#endif + +} // namespace instance +} // namespace device +} // namespace 
tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index 09885ccd90..9a70a47274 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -10,439 +10,18 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#ifdef CK_USE_XDL +#include "grouped_convolution_backward_data_xdl.inc" +#endif +#ifdef CK_USE_WMMA +#include "grouped_convolution_backward_data_wmma.inc" +#endif + namespace ck { namespace tensor_operation { namespace device { namespace instance { -// conv2d backward data -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( - std::vector>>& instances); 
- -void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( - std::vector>>& instances); - -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif -// conv3d backward data -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef 
CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif -#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 -void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances( - std::vector>>& instances); -#endif template > op_ptrs; + +#ifdef CK_USE_XDL if constexpr(NumDimSpatial == 2) { - if constexpr(is_same_v && is_same_v && is_same_v) { @@ -500,43 +80,28 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances( - op_ptrs); } #endif #ifdef CK_ENABLE_FP32 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( op_ptrs); } -#endif -#ifdef CK_ENABLE_INT8 - else if constexpr(is_same_v && is_same_v && - is_same_v && - is_same_v && - is_same_v) - { - add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs); - 
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances( - op_ptrs); - } #endif } - else if constexpr(is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && @@ -544,45 +109,29 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( - op_ptrs); } #endif #ifdef CK_ENABLE_FP32 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( op_ptrs); } -#endif -#ifdef CK_ENABLE_INT8 - else if constexpr(is_same_v && is_same_v && - is_same_v && - is_same_v && - is_same_v) - { - add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances( - op_ptrs); - } #endif } } - else if constexpr(NumDimSpatial == 3) + if constexpr(NumDimSpatial == 3) { - if constexpr(is_same_v && is_same_v && is_same_v) { @@ -593,35 +142,144 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( op_ptrs); - add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances( - op_ptrs); } #endif #ifdef 
CK_ENABLE_FP32 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances( op_ptrs); } #endif #ifdef CK_ENABLE_BF16 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances( op_ptrs); } +#endif + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( + op_ptrs); + } +#endif +#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + op_ptrs); + } +#endif + } + } +#endif + +#ifdef CK_USE_WMMA + if constexpr(NumDimSpatial == 2) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances( + op_ptrs); + } #endif #ifdef 
CK_ENABLE_INT8 - else if constexpr(is_same_v && is_same_v && - is_same_v && - is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances( + op_ptrs); + } +#endif + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances( + op_ptrs); + } +#endif + } + } + if constexpr(NumDimSpatial == 3) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances( op_ptrs); @@ -638,46 +296,16 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( - op_ptrs); add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances( op_ptrs); 
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances( op_ptrs); } #endif -#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances( - op_ptrs); - } -#endif -#ifdef CK_ENABLE_FP32 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( - op_ptrs); - } -#endif -#ifdef CK_ENABLE_BF16 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( - op_ptrs); - } -#endif #ifdef CK_ENABLE_INT8 - else if constexpr(is_same_v && is_same_v && - is_same_v && - is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances( op_ptrs); @@ -687,6 +315,7 @@ struct DeviceOperationInstanceFactory< #endif } } +#endif return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc new file mode 100644 index 0000000000..fb2407bcc3 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc @@ -0,0 +1,243 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// conv2d backward data +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances( + 
std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc new file mode 100644 index 0000000000..7ad0218410 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances( + std::vector>>& instances); + +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( + std::vector>>& instances); +#endif + +// conv3d backward data +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void 
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + std::vector>>& instances); +#endif +#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index b8ca2c5fac..dc56b8f4b2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -12,565 +12,20 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#ifdef DL_KERNELS +#include "grouped_convolution_backward_weight_dl.inc" +#endif +#ifdef CK_USE_XDL +#include "grouped_convolution_backward_weight_xdl.inc" +#endif +#ifdef CK_USE_WMMA +#include "grouped_convolution_backward_weight_wmma.inc" +#endif namespace ck { namespace tensor_operation { namespace device { namespace instance { -// xdl -// conv1d backward weight -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances( - std::vector>>& instances); -#endif -// conv2d backward weight -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( - std::vector>>& instances); -#endif -// conv3d backward weight -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 
-void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( - std::vector>>& instances); -#endif -#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif - -#ifdef DL_KERNELS -// dl -// conv1d backward weight -#ifdef CK_ENABLE_BF16 
-void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances( - std::vector>>& instances); -#endif -// conv2d backward weight -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances( - std::vector>>& instances); -#endif -// conv3d backward weight -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances( - std::vector>>& 
instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances( - std::vector>>& instances); -#endif -#endif - template > op_ptrs; +#ifdef DL_KERNELS if constexpr(NumDimSpatial == 1) { if constexpr(is_same_v && is_same_v && @@ -621,10 +77,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); -#endif - add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 @@ -632,10 +85,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); -#endif - add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -644,19 +94,14 @@ struct DeviceOperationInstanceFactory && is_same_v) { -#ifdef DL_KERNELS add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( op_ptrs); -#endif - add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( - op_ptrs); } #endif } if constexpr(is_same_v && is_same_v && is_same_v) { -#ifdef DL_KERNELS #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && @@ -683,6 +128,174 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + 
add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( + op_ptrs); + } +#endif + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( + op_ptrs); + } +#endif + } + } + if constexpr(NumDimSpatial == 3) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( + op_ptrs); + } +#endif + } + if 
constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + op_ptrs); + } +#endif + } + } +#endif // DL_KERNELS +#ifdef CK_USE_XDL + if constexpr(NumDimSpatial == 1) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( + op_ptrs); + } #endif } } @@ -696,10 +309,6 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances( - op_ptrs); -#endif add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( op_ptrs); } @@ -709,10 +318,6 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances( - op_ptrs); -#endif 
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( op_ptrs); } @@ -723,10 +328,6 @@ struct DeviceOperationInstanceFactory && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( - op_ptrs); -#endif add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( op_ptrs); } @@ -740,10 +341,6 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances( - op_ptrs); -#endif add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( op_ptrs); } @@ -753,10 +350,6 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances( - op_ptrs); -#endif add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances( op_ptrs); } @@ -767,10 +360,6 @@ struct DeviceOperationInstanceFactory && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( - op_ptrs); -#endif add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( op_ptrs); } @@ -787,10 +376,6 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances( - op_ptrs); -#endif add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( op_ptrs); } @@ -800,16 +385,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances( - op_ptrs); -#endif add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( op_ptrs); - add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( - op_ptrs); - 
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( - op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -818,13 +395,70 @@ struct DeviceOperationInstanceFactory && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( - op_ptrs); -#endif add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( op_ptrs); } +#endif + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + op_ptrs); + } +#endif +#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( + op_ptrs); + } +#endif + } + } +#endif +#ifdef CK_USE_WMMA + if constexpr(NumDimSpatial == 3) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( + op_ptrs); + } #endif #ifdef CK_ENABLE_INT8 else if constexpr(is_same_v && is_same_v && @@ -842,50 +476,17 @@ struct DeviceOperationInstanceFactory 
&& is_same_v && is_same_v) { -#ifdef CK_ENABLE_FP32 - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) - { -#ifdef DL_KERNELS - add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); -#endif - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); - } -#endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( - op_ptrs); -#endif - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( - op_ptrs); add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( op_ptrs); add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( op_ptrs); } #endif -#ifdef CK_ENABLE_BF16 - if constexpr(is_same_v && is_same_v && - is_same_v && - is_same_v && - is_same_v) - { -#ifdef DL_KERNELS - add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( - op_ptrs); -#endif - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( - op_ptrs); - } -#endif #ifdef CK_ENABLE_INT8 else if constexpr(is_same_v && is_same_v && is_same_v && @@ -897,18 +498,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( - op_ptrs); - } #endif } } +#endif return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_dl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_dl.inc new file mode 100644 index 0000000000..59190a13e5 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_dl.inc @@ -0,0 +1,243 @@ +// SPDX-License-Identifier: MIT +// 
Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// conv1d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances( + std::vector>>& instances); +#endif +// conv2d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances); +#endif +// conv3d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void 
add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc new file mode 100644 index 0000000000..315547ca56 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// conv3d backward weight +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc new file mode 100644 index 0000000000..5562d236e6 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc @@ -0,0 +1,228 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// conv1d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances( + std::vector>>& instances); +#endif +// conv2d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances); +#endif +// conv3d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void 
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances); +#endif +#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index b9712542a8..24a5f9a5cb 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -12,908 +12,21 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#ifdef DL_KERNELS +#include "grouped_convolution_forward_dl.inc" +#endif +#ifdef CK_USE_XDL +#include "grouped_convolution_forward_xdl.inc" +#endif +#ifdef CK_USE_WMMA +#include "grouped_convolution_forward_wmma.inc" +#endif + namespace ck { namespace tensor_operation { namespace device { namespace instance { -#ifdef CK_ENABLE_BF16 -// grouped conv1d forward, GNWC/GKXC/GNWK -void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances( - std::vector>>& instances); -#endif - 
-#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_BF16 -// grouped conv2d forward, GNHWC/GKYXC/GNHWK -void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances( - std::vector>>& instances); -#endif - -// grouped conv2d forward, NHWGC/GKYXC/NHWGK -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void 
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_BF16 -// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK -void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances( - std::vector>>& instances); - -void 
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_BF16 -// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP8 -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_BF8 -void 
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances( - std::vector>>& instances); -#endif - -#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS)) -void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances( - std::vector>>& instances); - -#endif - -#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS)) -void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances( - std::vector>>& instances); -#endif - -#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS)) -void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif - -#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS)) -void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances( - std::vector>>& instances); -#endif - template > op_ptrs; +#ifdef DL_KERNELS + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); + } +#endif + } + if 
constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { + +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); + } +#endif + } +#endif // DL_KERNELS + +#ifdef CK_USE_XDL if constexpr(NumDimSpatial == 1 && is_same_v && is_same_v && is_same_v) { @@ -1000,35 +154,13 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) - { - add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); - } -#endif - #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances(op_ptrs); } #endif - -#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS)) - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) - { - add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); - } -#endif - #ifdef CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && @@ -1037,23 +169,11 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) - { - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances(op_ptrs); - } -#endif } 
if constexpr(NumDimSpatial == 2 && is_same_v && is_same_v && is_same_v) { - #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) @@ -1061,15 +181,6 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) - { - add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); - } -#endif - #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) @@ -1077,15 +188,6 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) - { - add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); - } -#endif - #ifdef CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && @@ -1093,16 +195,6 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) - { - add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances(op_ptrs); - } #endif } @@ -1121,12 +213,6 @@ struct DeviceOperationInstanceFactory && is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -1142,11 +228,6 @@ struct DeviceOperationInstanceFactory && is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(op_ptrs); - 
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances(op_ptrs); } #endif } @@ -1188,12 +269,6 @@ struct DeviceOperationInstanceFactory && is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -1209,6 +284,99 @@ struct DeviceOperationInstanceFactory && is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs); + } +#endif + } +#endif + +#ifdef CK_USE_WMMA + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances(op_ptrs); + } +#endif + } + + if 
constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances(op_ptrs); + } +#endif + } + + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances(op_ptrs); + } +#endif + } + + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( + op_ptrs); + 
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( @@ -1217,6 +385,7 @@ struct DeviceOperationInstanceFactory>>& instances); + +void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma.inc new file mode 100644 index 0000000000..0ea24d0929 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma.inc @@ -0,0 +1,480 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( 
+ std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc new file mode 100644 index 0000000000..942674ef99 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -0,0 +1,357 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_BF16 +// grouped conv1d forward, GNWC/GKXC/GNWK +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +// grouped conv2d forward, GNHWC/GKYXC/GNHWK +void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances); +#endif + +// grouped conv2d forward, NHWGC/GKYXC/NHWGK +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void 
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); + +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP8 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF8 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 0a12e1c49e..c035e7e564 100644 --- 
a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -36,12 +36,27 @@ function(add_instance_library INSTANCE_NAME) endif() endforeach() endif() + # Do not build DL instances if DL_KERNELS macro is not set foreach(source IN LISTS ARGN) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") message("removing dl instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() endforeach() + # Do not build XDL instances if gfx9 targets are not on the target list + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + message("removing xdl instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + # Do not build WMMA instances if gfx11 targets are not on the target list + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + message("removing wmma instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() #only continue if there are some source files left on the list if(ARGN) add_library(${INSTANCE_NAME} OBJECT ${ARGN}) @@ -124,6 +139,26 @@ FOREACH(subdir_path ${dir_list}) message("Found only dl instances, but DL_KERNELS is not set. Skipping.") set(add_inst 0) endif() + if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx9")) + message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.") + set(add_inst 0) + endif() + if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11")) + message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.") + set(add_inst 0) + endif() + if(("${cmake_instance}" MATCHES "ONLY XDL_AND_DL_KERNELS") AND (NOT DEFINED DL_KERNELS) AND (NOT GPU_TARGETS MATCHES "gfx9")) + message("Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. 
Skipping.") + set(add_inst 0) + endif() + if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9")) + message("Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.") + set(add_inst 0) + endif() + if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS)) + message("Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.") + set(add_inst 0) + endif() if((add_inst EQUAL 1)) get_filename_component(target_dir ${subdir_path} NAME) add_subdirectory(${target_dir}) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt index 69b6ddc754..1227a77a38 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(BATCHED_GEMM_INSTANCES) list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt index d0e9b265af..5c8470f7cb 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_batched_gemm_add_relu_gemm_add_instance device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp diff --git 
a/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt index cd9c95c066..8082a8c275 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_batched_gemm_bias_permute_instance device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt index 865a31e79a..2aa607429d 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_batched_gemm_gemm_instance device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt index 28226fabac..51bbdf1d7c 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_batched_gemm_reduce_instance device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt 
b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt index 6244477e16..e43eb07fb6 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_batched_gemm_softmax_gemm_instance device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt index 3fd4e03449..f1fb0646e4 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES) list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt index 87a6bbba4e..a28c6717dd 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(DEVICE_CONTRACTION_BILINEAR_INSTANCES) # FP32 diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt index a0918d9d3f..b91de832e4 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt @@ 
-1,3 +1,4 @@ +# ONLY XDL_KERNELS set(DEVICE_CONTRACTION_SCALE_INSTANCES) # FP32 diff --git a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/CMakeLists.txt index 75a3670761..796a9b2402 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_conv1d_bwd_data_instance device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt index 49dfc01fd9..2da5155117 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_DL_KERNELS set(CONV2D_BWD_DATA_INSTANCES) list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt index ba0ca32517..04b313d075 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(DEVICE_CONV2D_FWD_INSTANCES) list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt index 670cd94fc9..4304d8996c 100644 --- 
a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_conv2d_fwd_bias_relu_instance device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt index 68d5f582fd..40a6b1ff09 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_conv2d_fwd_bias_relu_add_instance device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/CMakeLists.txt index db92208fd7..ec4a8a2864 100644 --- a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_conv3d_bwd_data_instance device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt index fe85bb7ead..298da1fbef 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_instance device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp 
device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt index bbf81a5fa2..04ae90bc5b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_add_fastgelu_instance device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt index 63b4a00c99..45d6abce01 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_fastgelu_instance device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt index eb9345cbad..d859078ca9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_multiply_instance device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt index 969361de9a..043bdab001 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_relu_instance device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt index 97693a2566..b9aeb6a6db 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_relu_add_layernorm_instance device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt index c10d4773a7..e6ca64cdc1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_silu_instance device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt 
index ccada3a85e..f29943d93b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_bias_add_reduce_instance device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt index 426edeed74..61892e708c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_bilinear_instance device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt index 17d27ab150..2f45401ec6 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_fastgelu_instance device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt index 6cbd7528e7..aba9806a74 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt +++ 
b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GEMM_MULTIPLY_ADD_INSTANCES) list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt index 2b2cf8c774..7ee3efe7f5 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_reduce_instance device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt index 059b6a720f..dac86d7707 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GEMM_SPLITK_INSTANCES) list(APPEND GEMM_SPLITK_INSTANCES diff --git a/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt index 8dd0112a6b..c854b16eeb 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_streamk_instance # device_gemm_xdl_streamk_f32_f32_f32_mk_kn_mn_instance.cpp # device_gemm_xdl_streamk_f32_f32_f32_mk_nk_mn_instance.cpp diff --git 
a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt index cfd829f87e..ab4313d89e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_DL_KERNELS set(GROUPED_CONV1D_BWD_WEIGHT xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt index f51a484bb5..ca4ea515bb 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_grouped_conv1d_fwd_instance xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt index 93d5bd7422..ad430340ea 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_WMMA_KERNELS add_instance_library( device_grouped_conv2d_bwd_data_instance xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt index 8a896b06c7..340ddfb3f0 100644 --- 
a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_DL_KERNELS set(GROUPED_CONV2D_BWD_WEIGHT xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 2715a8cf21..1d3c3747d3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -1,3 +1,4 @@ +# XDL_DL_WMMA_KERNELS add_instance_library(device_grouped_conv2d_fwd_instance #xdl # GNHWC, GKYXC, GNHWK diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt index 836e671bf2..29fa8fa3c5 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_WMMA_KERNELS set(GROUPED_CONV3D_BWD_DATA xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt index e1cb975291..ae6dcb9880 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_CONV3D_BWD_DATA_BILINEAR 
xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt index b7901a2815..fa48f0edcc 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_CONV3D_BWD_DATA_BILINEAR xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index 968e8dea2f..8b89dcf7ec 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -1,3 +1,4 @@ +# XDL_DL_WMMA_KERNELS set(GROUPED_CONV3D_BWD_WEIGHT xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 3825b92af4..972fb54031 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_WMMA_KERNELS set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp 
xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt index 49706588d6..436c37fd58 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_BILINEAR xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt index 45d270d554..f36d55d367 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_BILINEAR xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt index 08fb23afc9..1076249447 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_SCALEADD_AB xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git 
a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt index ae89caaeef..1be1db7d1d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_scaleadd_scaleadd_RELU xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt index de7537af47..2625e6cbe8 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_grouped_gemm_instance device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt index ef8a440c1a..167dfa9a6f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_grouped_gemm_bias_instance device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt 
b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt index 648f2146cb..8e9693e691 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_grouped_gemm_fastgelu_instance device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt index ac22543bef..1ee3d0add4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_GEMM_FIXED_NK_INSTANCES) list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt index c22a6e9e96..5d50902be8 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_DL_KERNELS set(CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp) set(CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp) set(CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp) diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 11ae285167..cb6ffbec6c 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -2,19 +2,6 @@ set(PROFILER_SOURCES 
profiler.cpp profile_gemm.cpp - profile_gemm_splitk.cpp - profile_gemm_bias_add_reduce.cpp - profile_gemm_add_multiply.cpp - profile_gemm_multiply_add.cpp - profile_gemm_reduce.cpp - profile_batched_gemm.cpp - profile_batched_gemm_reduce.cpp - profile_conv_fwd.cpp - profile_conv_fwd_bias_relu.cpp - profile_conv_fwd_bias_relu_add.cpp - profile_conv_bwd_data.cpp - profile_grouped_conv_fwd.cpp - profile_grouped_conv_bwd_weight.cpp profile_reduce.cpp profile_groupnorm_bwd_data.cpp profile_groupnorm_fwd.cpp @@ -29,36 +16,57 @@ set(PROFILER_SOURCES profile_batchnorm_fwd.cpp profile_batchnorm_bwd.cpp profile_batchnorm_infer.cpp - profile_grouped_conv_bwd_data.cpp profile_conv_tensor_rearrange.cpp profile_transpose.cpp profile_permute_scale.cpp ) +if(GPU_TARGETS MATCHES "gfx9") + if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) + list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp) + list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp) + endif() + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND PROFILER_SOURCES profile_gemm_reduce.cpp) + list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp) + list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) + endif() + list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp) + list(APPEND 
PROFILER_SOURCES profile_batched_gemm.cpp) + list(APPEND PROFILER_SOURCES profile_batched_gemm_reduce.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp) + list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu.cpp) + list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp) + list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp) + list(APPEND PROFILER_SOURCES profile_conv_fwd.cpp) + +endif() + +if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx9") + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) + endif() + list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_data.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) +endif() + if(DL_KERNELS) list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp) -endif() - -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) - list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) -endif() - -if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" 
OR NOT DEFINED DTYPES) - list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp) - list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) endif() set(PROFILER_EXECUTABLE ckProfiler) @@ -68,25 +76,6 @@ target_compile_options(${PROFILER_EXECUTABLE} PRIVATE -Wno-global-constructors) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility getopt::getopt) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE 
device_grouped_conv3d_bwd_weight_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_gamma_beta_instance) @@ -96,39 +85,65 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance) -if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance) +if(GPU_TARGETS MATCHES "gfx9") + if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance) + endif() + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + 
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) + endif() + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance) + 
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) endif() - +if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11") + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) + endif() + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) +endif() if(DL_KERNELS) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance) -endif() - -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) - 
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) endif() rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index a0f90256c0..720ab468ea 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -46,7 +46,18 @@ function(add_test_executable TEST_NAME) list(REMOVE_ITEM ARGN "${source}") endif() endforeach() - + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "xdl") + message("removing xdl test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") + message("removing wmma test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() #only continue if there are some source files left on the list if(ARGN) add_executable(${TEST_NAME} ${ARGN}) @@ -100,6 +111,18 @@ function(add_gtest_executable 
TEST_NAME) list(REMOVE_ITEM ARGN "${source}") endif() endforeach() + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "xdl") + message("removing xdl test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") + message("removing wmma test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() #only continue if there are some source files left on the list if(ARGN) add_executable(${TEST_NAME} ${ARGN}) diff --git a/test/batched_gemm/CMakeLists.txt b/test/batched_gemm/CMakeLists.txt index 9482821b68..759cf3da67 100644 --- a/test/batched_gemm/CMakeLists.txt +++ b/test/batched_gemm/CMakeLists.txt @@ -1,9 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_gtest_executable(test_batched_gemm test_batched_gemm.cpp) +add_gtest_executable(test_batched_gemm test_batched_gemm_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_batched_gemm PRIVATE utility device_batched_gemm_instance) - set(target 1) - endif() -endforeach() \ No newline at end of file +endif() diff --git a/test/batched_gemm/test_batched_gemm.cpp b/test/batched_gemm/test_batched_gemm_xdl.cpp similarity index 100% rename from test/batched_gemm/test_batched_gemm.cpp rename to test/batched_gemm/test_batched_gemm_xdl.cpp diff --git a/test/batched_gemm_gemm/CMakeLists.txt b/test/batched_gemm_gemm/CMakeLists.txt index 03f1d3a4eb..2b3288ef9d 100644 --- a/test/batched_gemm_gemm/CMakeLists.txt +++ b/test/batched_gemm_gemm/CMakeLists.txt @@ -1,13 +1,6 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(test_batched_gemm_gemm) - add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp) - 
if(result EQUAL 0) - target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance) - add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16) - set(target 1) - endif() - endif() -endforeach() \ No newline at end of file +add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16_xdl.cpp) +if(result EQUAL 0) + add_custom_target(test_batched_gemm_gemm) + target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance) + add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16) +endif() diff --git a/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp b/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_xdl.cpp similarity index 100% rename from test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp rename to test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_xdl.cpp diff --git a/test/batched_gemm_reduce/CMakeLists.txt b/test/batched_gemm_reduce/CMakeLists.txt index 32c6ee85d1..c5868e4d7a 100644 --- a/test/batched_gemm_reduce/CMakeLists.txt +++ b/test/batched_gemm_reduce/CMakeLists.txt @@ -1,11 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility device_batched_gemm_reduce_instance) - set(target 1) - endif() +add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility device_batched_gemm_reduce_instance) endif() -endforeach() diff --git a/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp b/test/batched_gemm_reduce/batched_gemm_reduce_fp16_xdl.cpp similarity index 100% rename from test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp rename to 
test/batched_gemm_reduce/batched_gemm_reduce_fp16_xdl.cpp diff --git a/test/batched_gemm_softmax_gemm/CMakeLists.txt b/test/batched_gemm_softmax_gemm/CMakeLists.txt index c011a6a3c5..c042d7e000 100644 --- a/test/batched_gemm_softmax_gemm/CMakeLists.txt +++ b/test/batched_gemm_softmax_gemm/CMakeLists.txt @@ -1,13 +1,6 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(test_batched_gemm_softmax_gemm) - add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance) - add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16) - set(target 1) - endif() - endif() -endforeach() \ No newline at end of file +add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16_xdl.cpp) +if(result EQUAL 0) + add_custom_target(test_batched_gemm_softmax_gemm) + target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance) + add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16) +endif() diff --git a/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp b/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16_xdl.cpp similarity index 100% rename from test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp rename to test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16_xdl.cpp diff --git a/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt b/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt index 3164863eef..2e09073540 100644 --- a/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt +++ b/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt @@ -1,29 +1,21 @@ -list(APPEND gpu_list gfx908 
gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(test_batched_gemm_softmax_gemm_permute) - add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16) - endif() - add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16) - endif() - - add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16) - endif() - add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16) - endif() - set(target 1) - endif() -endforeach() \ No newline at end of file +add_custom_target(test_batched_gemm_softmax_gemm_permute) +add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 
test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16) +endif() +add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16) +endif() +add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16) +endif() +add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16) +endif() diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16_xdl.cpp similarity index 100% rename from test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp rename to test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16_xdl.cpp diff --git 
a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16_xdl.cpp similarity index 100% rename from test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp rename to test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16_xdl.cpp diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp similarity index 100% rename from test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp rename to test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp similarity index 100% rename from test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp rename to test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp diff --git a/test/contraction/CMakeLists.txt b/test/contraction/CMakeLists.txt index a86e72fddb..3ba0d82f0e 100644 --- a/test/contraction/CMakeLists.txt +++ b/test/contraction/CMakeLists.txt @@ -1,13 +1,10 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - if((DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64") OR NOT DEFINED DTYPES) - add_gtest_executable(test_contraction test_contraction.cpp) - target_link_libraries(test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) - add_gtest_executable(test_contraction_interface test_contraction_interface.cpp) - 
target_link_libraries(test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) - set(target 1) - endif() +if((DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64") OR NOT DEFINED DTYPES) + add_gtest_executable(test_contraction test_contraction_xdl.cpp) + if(result EQUAL 0) + target_link_libraries(test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) endif() -endforeach() + add_gtest_executable(test_contraction_interface test_contraction_interface_xdl.cpp) + if(result EQUAL 0) + target_link_libraries(test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) + endif() +endif() diff --git a/test/contraction/test_contraction_interface.cpp b/test/contraction/test_contraction_interface_xdl.cpp similarity index 100% rename from test/contraction/test_contraction_interface.cpp rename to test/contraction/test_contraction_interface_xdl.cpp diff --git a/test/contraction/test_contraction.cpp b/test/contraction/test_contraction_xdl.cpp similarity index 100% rename from test/contraction/test_contraction.cpp rename to test/contraction/test_contraction_xdl.cpp diff --git a/test/convnd_bwd_data/CMakeLists.txt b/test/convnd_bwd_data/CMakeLists.txt index f734b46f53..e68a9b243c 100644 --- a/test/convnd_bwd_data/CMakeLists.txt +++ b/test/convnd_bwd_data/CMakeLists.txt @@ -1,9 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_gtest_executable(test_convnd_bwd_data convnd_bwd_data.cpp) +add_gtest_executable(test_convnd_bwd_data convnd_bwd_data_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_convnd_bwd_data PRIVATE utility device_conv1d_bwd_data_instance device_conv2d_bwd_data_instance device_conv3d_bwd_data_instance) - set(target 1) - endif() -endforeach() \ No newline at end of file +endif() diff --git 
a/test/convnd_bwd_data/convnd_bwd_data.cpp b/test/convnd_bwd_data/convnd_bwd_data_xdl.cpp similarity index 100% rename from test/convnd_bwd_data/convnd_bwd_data.cpp rename to test/convnd_bwd_data/convnd_bwd_data_xdl.cpp diff --git a/test/convnd_fwd/CMakeLists.txt b/test/convnd_fwd/CMakeLists.txt index 745aceffc9..ba6d16a0d5 100644 --- a/test/convnd_fwd/CMakeLists.txt +++ b/test/convnd_fwd/CMakeLists.txt @@ -1,9 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_gtest_executable(test_convnd_fwd convnd_fwd.cpp) +add_gtest_executable(test_convnd_fwd convnd_fwd_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_convnd_fwd PRIVATE utility device_conv2d_fwd_instance) - set(target 1) - endif() -endforeach() +endif() diff --git a/test/convnd_fwd/convnd_fwd.cpp b/test/convnd_fwd/convnd_fwd_xdl.cpp similarity index 100% rename from test/convnd_fwd/convnd_fwd.cpp rename to test/convnd_fwd/convnd_fwd_xdl.cpp diff --git a/test/gemm_add/CMakeLists.txt b/test/gemm_add/CMakeLists.txt index 7df3f90abc..ab4c781847 100644 --- a/test/gemm_add/CMakeLists.txt +++ b/test/gemm_add/CMakeLists.txt @@ -1,11 +1,19 @@ -add_gtest_executable(test_gemm_add test_gemm_add.hpp) -target_link_libraries(test_gemm_add PRIVATE utility device_gemm_add_instance) +add_gtest_executable(test_gemm_add test_gemm_add_xdl.hpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add PRIVATE utility device_gemm_add_instance) +endif() -add_gtest_executable(test_gemm_add_relu test_gemm_add_relu.cpp) -target_link_libraries(test_gemm_add_relu PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance) +add_gtest_executable(test_gemm_add_relu test_gemm_add_relu_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_relu PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance) +endif() -add_gtest_executable(test_gemm_add_silu test_gemm_add_silu.cpp) 
-target_link_libraries(test_gemm_add_silu PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance) +add_gtest_executable(test_gemm_add_silu test_gemm_add_silu_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_silu PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance) +endif() -add_gtest_executable(test_gemm_add_fastgelu test_gemm_add_fastgelu.cpp) -target_link_libraries(test_gemm_add_fastgelu PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance) +add_gtest_executable(test_gemm_add_fastgelu test_gemm_add_fastgelu_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_fastgelu PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance) +endif() diff --git a/test/gemm_add/test_gemm_add_fastgelu.cpp b/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp similarity index 98% rename from test/gemm_add/test_gemm_add_fastgelu.cpp rename to test/gemm_add/test_gemm_add_fastgelu_xdl.cpp index c1c55140a0..1b12ab7528 100644 --- a/test/gemm_add/test_gemm_add_fastgelu.cpp +++ b/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp @@ -4,7 +4,7 @@ #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_fastgelu_impl.hpp" -#include "test_gemm_add.hpp" +#include "test_gemm_add_xdl.hpp" template class TestGemmAddFastgelu : public TestGemmAdd diff --git a/test/gemm_add/test_gemm_add_relu.cpp b/test/gemm_add/test_gemm_add_relu_xdl.cpp similarity index 98% rename from test/gemm_add/test_gemm_add_relu.cpp rename to test/gemm_add/test_gemm_add_relu_xdl.cpp index ba6aab36bd..e8b769b1cb 100644 --- a/test/gemm_add/test_gemm_add_relu.cpp +++ b/test/gemm_add/test_gemm_add_relu_xdl.cpp @@ -4,7 +4,7 @@ #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_relu_impl.hpp" -#include "test_gemm_add.hpp" +#include "test_gemm_add_xdl.hpp" template class TestGemmAddRelu : public TestGemmAdd diff --git a/test/gemm_add/test_gemm_add_silu.cpp 
b/test/gemm_add/test_gemm_add_silu_xdl.cpp similarity index 98% rename from test/gemm_add/test_gemm_add_silu.cpp rename to test/gemm_add/test_gemm_add_silu_xdl.cpp index d4dd6fa38b..75fa59a8e7 100644 --- a/test/gemm_add/test_gemm_add_silu.cpp +++ b/test/gemm_add/test_gemm_add_silu_xdl.cpp @@ -4,7 +4,7 @@ #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_silu_impl.hpp" -#include "test_gemm_add.hpp" +#include "test_gemm_add_xdl.hpp" template class TestGemmAddSilu : public TestGemmAdd diff --git a/test/gemm_add/test_gemm_add.hpp b/test/gemm_add/test_gemm_add_xdl.hpp similarity index 100% rename from test/gemm_add/test_gemm_add.hpp rename to test/gemm_add/test_gemm_add_xdl.hpp diff --git a/test/gemm_layernorm/CMakeLists.txt b/test/gemm_layernorm/CMakeLists.txt index bfc4404bd8..d1102a561a 100644 --- a/test/gemm_layernorm/CMakeLists.txt +++ b/test/gemm_layernorm/CMakeLists.txt @@ -1,13 +1,6 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(test_gemm_layernorm) - add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance) - add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16) - set(target 1) - endif() - endif() -endforeach() +add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16_xdl.cpp) +if(result EQUAL 0) + add_custom_target(test_gemm_layernorm) + target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance) + add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16) +endif() diff --git a/test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp 
b/test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16_xdl.cpp similarity index 100% rename from test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp rename to test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16_xdl.cpp diff --git a/test/gemm_reduce/CMakeLists.txt b/test/gemm_reduce/CMakeLists.txt index 42a53c3048..121ecde609 100644 --- a/test/gemm_reduce/CMakeLists.txt +++ b/test/gemm_reduce/CMakeLists.txt @@ -1,4 +1,4 @@ -add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp) +add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16_xdl.cpp) if(result EQUAL 0) target_link_libraries(test_gemm_reduce_fp16 PRIVATE utility device_gemm_reduce_instance) endif() \ No newline at end of file diff --git a/test/gemm_reduce/gemm_reduce_fp16.cpp b/test/gemm_reduce/gemm_reduce_fp16_xdl.cpp similarity index 100% rename from test/gemm_reduce/gemm_reduce_fp16.cpp rename to test/gemm_reduce/gemm_reduce_fp16_xdl.cpp diff --git a/test/gemm_split_k/CMakeLists.txt b/test/gemm_split_k/CMakeLists.txt index caf30fca59..4b66dddef9 100644 --- a/test/gemm_split_k/CMakeLists.txt +++ b/test/gemm_split_k/CMakeLists.txt @@ -1,9 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_gtest_executable(test_gemm_splitk test_gemm_splitk.cpp) +add_gtest_executable(test_gemm_splitk test_gemm_splitk_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_gemm_splitk PRIVATE utility device_gemm_splitk_instance) - set(target 1) endif() -endforeach() diff --git a/test/gemm_split_k/test_gemm_splitk.cpp b/test/gemm_split_k/test_gemm_splitk_xdl.cpp similarity index 100% rename from test/gemm_split_k/test_gemm_splitk.cpp rename to test/gemm_split_k/test_gemm_splitk_xdl.cpp diff --git a/test/grouped_convnd_bwd_data/CMakeLists.txt b/test/grouped_convnd_bwd_data/CMakeLists.txt index 305c568ee9..3507989bae 100644 --- a/test/grouped_convnd_bwd_data/CMakeLists.txt +++ 
b/test/grouped_convnd_bwd_data/CMakeLists.txt @@ -1,19 +1,12 @@ -list(APPEND gpu_list_xdl gfx908 gfx90a gfx940) -list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) - add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data.cpp) - target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) - add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_xdl.cpp) - target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) - set(target 1) - endif() - if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) - add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data.cpp) - target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) - add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_wmma.cpp) - target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) - set(target 1) - endif() -endforeach() \ No newline at end of file +add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data_xdl_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) +endif() +add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) +endif() +add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_wmma.cpp) 
+if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) +endif() diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl_wmma.cpp similarity index 100% rename from test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp rename to test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl_wmma.cpp diff --git a/test/grouped_convnd_bwd_weight/CMakeLists.txt b/test/grouped_convnd_bwd_weight/CMakeLists.txt index d7d6f8a3d6..34cdc63cd9 100644 --- a/test/grouped_convnd_bwd_weight/CMakeLists.txt +++ b/test/grouped_convnd_bwd_weight/CMakeLists.txt @@ -1,20 +1,12 @@ -list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103) - -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) - add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) +add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight_xdl_wmma.cpp) +if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance) - add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_xdl.cpp) +endif() +add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility) - set(target 1) - endif() - if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) - add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) - target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance 
device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance) - add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_wmma.cpp) +endif() +add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_wmma.cpp) +if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility) - set(target 1) - endif() -endforeach() \ No newline at end of file +endif() diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_xdl_wmma.cpp similarity index 100% rename from test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp rename to test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_xdl_wmma.cpp diff --git a/test/grouped_convnd_fwd/CMakeLists.txt b/test/grouped_convnd_fwd/CMakeLists.txt index 1ce878d5ca..4f245d63cd 100644 --- a/test/grouped_convnd_fwd/CMakeLists.txt +++ b/test/grouped_convnd_fwd/CMakeLists.txt @@ -1,8 +1,14 @@ -add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd.cpp) -target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) +add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd_xdl_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) +endif() add_gtest_executable(test_grouped_convnd_fwd_multi_ab_interface test_grouped_convnd_fwd_multi_ab_interface.cpp) -target_link_libraries(test_grouped_convnd_fwd_multi_ab_interface PRIVATE utility) +if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_fwd_multi_ab_interface PRIVATE utility) +endif() -add_gtest_executable(test_grouped_convnd_fwd_multi_d_interface_compatibility 
test_grouped_convnd_fwd_multi_d_interface_compatibility.cpp) -target_link_libraries(test_grouped_convnd_fwd_multi_d_interface_compatibility PRIVATE utility device_grouped_conv3d_fwd_instance) +add_gtest_executable(test_grouped_convnd_fwd_multi_d_interface_compatibility test_grouped_convnd_fwd_multi_d_interface_compatibility_xdl_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_fwd_multi_d_interface_compatibility PRIVATE utility device_grouped_conv3d_fwd_instance) +endif() diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_d_interface_compatibility.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_d_interface_compatibility_xdl_wmma.cpp similarity index 100% rename from test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_d_interface_compatibility.cpp rename to test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_d_interface_compatibility_xdl_wmma.cpp diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_xdl_wmma.cpp similarity index 100% rename from test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp rename to test/grouped_convnd_fwd/test_grouped_convnd_fwd_xdl_wmma.cpp diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt index 8c57b667e2..f47685cf91 100644 --- a/test/grouped_gemm/CMakeLists.txt +++ b/test/grouped_gemm/CMakeLists.txt @@ -1,14 +1,13 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(test_grouped_gemm) - add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk.cpp) - add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface.cpp) - target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance) - target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance) - - add_dependencies(test_grouped_gemm 
test_grouped_gemm_splitk test_grouped_gemm_interface) - set(target 1) - endif() -endforeach() +add_custom_target(test_grouped_gemm) + +add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance) + add_dependencies(test_grouped_gemm test_grouped_gemm_splitk) +endif() + +add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance) + add_dependencies(test_grouped_gemm test_grouped_gemm_interface) +endif() diff --git a/test/grouped_gemm/test_grouped_gemm_interface.cpp b/test/grouped_gemm/test_grouped_gemm_interface_xdl.cpp similarity index 100% rename from test/grouped_gemm/test_grouped_gemm_interface.cpp rename to test/grouped_gemm/test_grouped_gemm_interface_xdl.cpp diff --git a/test/grouped_gemm/test_grouped_gemm_splitk.cpp b/test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp similarity index 100% rename from test/grouped_gemm/test_grouped_gemm_splitk.cpp rename to test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp diff --git a/test/normalization_bwd_data/CMakeLists.txt b/test/normalization_bwd_data/CMakeLists.txt index 1b6decfed7..65f33da74d 100644 --- a/test/normalization_bwd_data/CMakeLists.txt +++ b/test/normalization_bwd_data/CMakeLists.txt @@ -1,13 +1,8 @@ add_custom_target(test_normalization_bwd_data) add_gtest_executable(test_layernorm2d_bwd_data_fp32 test_layernorm2d_bwd_data_fp32.cpp) -if(result EQUAL 0) - target_link_libraries(test_layernorm2d_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance) - add_dependencies(test_normalization_bwd_data test_layernorm2d_bwd_data_fp32) -endif() +target_link_libraries(test_layernorm2d_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance) +add_dependencies(test_normalization_bwd_data test_layernorm2d_bwd_data_fp32) 
add_gtest_executable(test_groupnorm_bwd_data_fp32 test_groupnorm_bwd_data_fp32.cpp) -if(result EQUAL 0) - target_link_libraries(test_groupnorm_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance) - add_dependencies(test_normalization_bwd_data test_groupnorm_bwd_data_fp32) -endif() - +target_link_libraries(test_groupnorm_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance) +add_dependencies(test_normalization_bwd_data test_groupnorm_bwd_data_fp32) diff --git a/test/normalization_bwd_gamma_beta/CMakeLists.txt b/test/normalization_bwd_gamma_beta/CMakeLists.txt index f3579aad08..afb78dc58e 100644 --- a/test/normalization_bwd_gamma_beta/CMakeLists.txt +++ b/test/normalization_bwd_gamma_beta/CMakeLists.txt @@ -1,13 +1,8 @@ add_custom_target(test_normalization_bwd_gamma_beta) add_gtest_executable(test_layernorm2d_bwd_gamma_beta_fp32 test_layernorm2d_bwd_gamma_beta_fp32.cpp) -if(result EQUAL 0) - target_link_libraries(test_layernorm2d_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance) - add_dependencies(test_normalization_bwd_gamma_beta test_layernorm2d_bwd_gamma_beta_fp32) -endif() +target_link_libraries(test_layernorm2d_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance) +add_dependencies(test_normalization_bwd_gamma_beta test_layernorm2d_bwd_gamma_beta_fp32) add_gtest_executable(test_groupnorm_bwd_gamma_beta_fp32 test_groupnorm_bwd_gamma_beta_fp32.cpp) -if(result EQUAL 0) - target_link_libraries(test_groupnorm_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance) - add_dependencies(test_normalization_bwd_gamma_beta test_groupnorm_bwd_gamma_beta_fp32) -endif() - +target_link_libraries(test_groupnorm_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance) +add_dependencies(test_normalization_bwd_gamma_beta test_groupnorm_bwd_gamma_beta_fp32) diff --git a/test/permute_scale/CMakeLists.txt b/test/permute_scale/CMakeLists.txt 
index be6aaf94aa..d63cb79910 100644 --- a/test/permute_scale/CMakeLists.txt +++ b/test/permute_scale/CMakeLists.txt @@ -1,6 +1,4 @@ add_custom_target(test_permute) add_gtest_executable(test_permute_scale test_permute_scale.cpp) -if(result EQUAL 0) - target_link_libraries(test_permute_scale PRIVATE utility device_permute_scale_instance) - add_dependencies(test_permute test_permute_scale) -endif() +target_link_libraries(test_permute_scale PRIVATE utility device_permute_scale_instance) +add_dependencies(test_permute test_permute_scale) diff --git a/test/transpose/CMakeLists.txt b/test/transpose/CMakeLists.txt index 530cc9d72d..fb9379bea9 100644 --- a/test/transpose/CMakeLists.txt +++ b/test/transpose/CMakeLists.txt @@ -1,9 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_gtest_executable(test_transpose test_transpose.cpp) - target_link_libraries(test_transpose PRIVATE utility device_transpose_instance) - set(target 1) - endif() -endforeach() +add_gtest_executable(test_transpose test_transpose_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_transpose PRIVATE utility device_transpose_instance) +endif() diff --git a/test/transpose/test_transpose.cpp b/test/transpose/test_transpose_xdl.cpp similarity index 100% rename from test/transpose/test_transpose.cpp rename to test/transpose/test_transpose_xdl.cpp diff --git a/test/wrapper/CMakeLists.txt b/test/wrapper/CMakeLists.txt index 383707828c..1eb6c35db2 100644 --- a/test/wrapper/CMakeLists.txt +++ b/test/wrapper/CMakeLists.txt @@ -12,10 +12,8 @@ add_dependencies(test_wrapper test_wrapper_copy) add_gtest_executable(test_wrapper_partition test_wrapper_partition.cpp) target_link_libraries(test_wrapper_partition PRIVATE utility) add_dependencies(test_wrapper test_wrapper_partition) -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR - GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES 
"gfx941" OR - GPU_TARGETS MATCHES "gfx942") - add_gtest_executable(test_wrapper_gemm test_wrapper_gemm.cpp) +add_gtest_executable(test_wrapper_gemm test_wrapper_gemm_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_wrapper_gemm PRIVATE utility) add_dependencies(test_wrapper test_wrapper_gemm) endif() diff --git a/test/wrapper/test_wrapper_gemm.cpp b/test/wrapper/test_wrapper_gemm_xdl.cpp similarity index 100% rename from test/wrapper/test_wrapper_gemm.cpp rename to test/wrapper/test_wrapper_gemm_xdl.cpp From 9a194837af0e0d71399d751d9a30f5b6ee4843ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 3 Apr 2024 00:23:49 +0200 Subject: [PATCH 3/7] Introduce combined elementwise ops (#1217) * Introduce combined elementwise ops * Introduce refrence elementwise --- example/44_elementwise_permute/CMakeLists.txt | 2 + .../elementwise_binary_4D_fp16.cpp | 140 +++++ .../elementwise_permute.cpp | 67 +-- .../elementwise_permute_3d.cpp | 51 +- .../elementwise_permute_4D_fp16.cpp | 54 +- .../elementwise_permute_4D_fp16_2d.cpp | 56 +- .../elementwise_permute_4D_fp16_col.cpp | 87 ++- .../elementwise_permute_4D_fp16_row.cpp | 73 +-- .../elementwise_permute_4D_fp32_col.cpp | 85 +-- .../elementwise_permute_4D_fp32_row.cpp | 72 +-- .../elementwise_trinary_4D_fp16.cpp | 156 +++++ .../element/binary_element_wise_operation.hpp | 104 ++++ .../combined_element_wise_operation.hpp | 103 ++++ .../element/unary_element_wise_operation.hpp | 248 +++++++- ...idwise_elementwise_dynamic_vector_dims.hpp | 16 +- include/ck/utility/math_v2.hpp | 556 +++++++++++++++++- .../cpu/reference_elementwise.hpp | 110 ++++ .../profiler/profile_permute_scale_impl.hpp | 22 +- 18 files changed, 1694 insertions(+), 308 deletions(-) create mode 100644 example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp create mode 100644 example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp create mode 100644 
include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp create mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp diff --git a/example/44_elementwise_permute/CMakeLists.txt b/example/44_elementwise_permute/CMakeLists.txt index a963399dc7..3cf4812509 100644 --- a/example/44_elementwise_permute/CMakeLists.txt +++ b/example/44_elementwise_permute/CMakeLists.txt @@ -4,6 +4,8 @@ add_example_executable(example_elementwise_permute_4D_fp32_row elementwise_permu add_example_executable(example_elementwise_permute_4D_fp16_row elementwise_permute_4D_fp16_row.cpp) add_example_executable(example_elementwise_permute_4D_fp32_col elementwise_permute_4D_fp32_col.cpp) add_example_executable(example_elementwise_permute_4D_fp16_col elementwise_permute_4D_fp16_col.cpp) +add_example_executable(example_elementwise_binary_4D_fp16 elementwise_binary_4D_fp16.cpp) +add_example_executable(example_elementwise_trinary_4D_fp16 elementwise_trinary_4D_fp16.cpp) add_example_executable(example_elementwise_permute elementwise_permute.cpp) if((NOT GPU_TARGETS MATCHES "gfx940") AND (NOT GPU_TARGETS MATCHES "gfx941") AND (NOT GPU_TARGETS MATCHES "gfx942")) add_example_executable(example_elementwise_permute_3d elementwise_permute_3d.cpp) diff --git a/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp new file mode 100644 index 0000000000..8819bb65e6 --- /dev/null +++ b/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; + +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; +using BinaryAdd = ck::tensor_operation::element_wise::Add; +// B = alpha * A0 * A0 + beta * A1 * A1 +using BinaryAddUnaryScaleSquare = ck::tensor_operation::element_wise:: + BinaryWithUnaryCombinedOp; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + BinaryAddUnaryScaleSquare, // ElementwiseOp + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<8, 8>, // InScalarPerVectorSeq + ck::Sequence<8>>; // OutScalarPerVectorSeq + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + + std::vector nchw = {16, 128, 32, 64}; + std::array ab_lengths; + std::array ab_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), + static_cast(nchw[2] * nchw[3]), + static_cast(nchw[3]), + 1}; + ck::ranges::copy(nchw, ab_lengths.begin()); + + std::array, 2> as = {Tensor(ab_lengths, ab_strides), + Tensor(ab_lengths, ab_strides)}; + Tensor& a0 = as[0]; + Tensor& a1 = as[1]; + Tensor 
b(ab_lengths, ab_strides); + float alpha = 3.f; + float beta = 2.f; + a0.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a0_device_buf(sizeof(ADataType) * a0.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(ADataType) * a1.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0.mData.data()); + a1_device_buf.ToDevice(a1.mData.data()); + + std::array inputs = {a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + + auto broadcastPermute = DeviceElementwisePermuteInstance{}; + auto unary_scale_op_a0 = UnaryScaleSquare{UnarySquare{}, UnaryScale{alpha}}; + auto unary_scale_op_a1 = UnaryScaleSquare{UnarySquare{}, UnaryScale{beta}}; + auto argument = broadcastPermute.MakeArgumentPointer( + ab_lengths, + {ab_strides, ab_strides}, + {ab_strides}, + inputs, + output, + BinaryAddUnaryScaleSquare{BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1}); + + if(!broadcastPermute.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + std::cout << "A0 (nchw): " << a0.mDesc << std::endl; + std::cout << "A1 (nchw): " << a1.mDesc << std::endl; + std::cout << "B (nchw): " << b.mDesc << std::endl; + + auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); + float ave_time = + broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + + std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + + sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time 
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + Tensor host_b(ab_lengths, ab_strides); + + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<2, ADataType, BDataType, BinaryAddUnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument( + as, + host_b, + BinaryAddUnaryScaleSquare{BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); + pass &= + ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/44_elementwise_permute/elementwise_permute.cpp b/example/44_elementwise_permute/elementwise_permute.cpp index 24e161c6d3..d3c3085eb8 100644 --- a/example/44_elementwise_permute/elementwise_permute.cpp +++ b/example/44_elementwise_permute/elementwise_permute.cpp @@ -8,6 +8,8 @@ #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -30,20 +32,6 @@ using DeviceElementwisePermuteInstance = ck::Sequence<1>, // InScalarPerVectorSeq ck::Sequence<1>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_ndhwc, const HostTensorA& A_ncdhw, Functor functor) -{ - for(std::size_t n = 0; n < A_ncdhw.mDesc.GetLengths()[0]; ++n) - for(std::size_t c = 0; c < A_ncdhw.mDesc.GetLengths()[1]; ++c) - for(std::size_t d = 0; d < A_ncdhw.mDesc.GetLengths()[2]; ++d) - for(std::size_t h = 0; h < A_ncdhw.mDesc.GetLengths()[3]; ++h) - for(std::size_t w = 0; w < 
A_ncdhw.mDesc.GetLengths()[4]; ++w) - { - auto a_val = A_ncdhw(n, c, d, h, w); - functor(B_ndhwc(n, d, h, w, c), a_val); - } -} - int main() { bool do_verification = true; @@ -51,32 +39,7 @@ int main() std::vector ncdhw = {16, 8, 8, 8, 8}; std::vector ndhwc = {16, 8, 8, 8, 8}; - Tensor a(ncdhw); - Tensor b(ndhwc); - - a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a.mData.data()); - - std::array input = {a_device_buf.GetDeviceBuffer()}; - std::array output = {b_device_buf.GetDeviceBuffer()}; - std::array ab_lengths; - /**std::array a_strides = { - static_cast(ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]), - static_cast(ncdhw[2] * ncdhw[3] * ncdhw[4]), - static_cast(ncdhw[3] * ncdhw[4]), - static_cast(ncdhw[4]), - 1}; - std::array b_strides = { - static_cast(ndhwc[1] * ndhwc[2] * ndhwc[3] * ndhwc[4]), - static_cast(ndhwc[2] * ndhwc[3] * ndhwc[4]), - 1, - static_cast(ndhwc[3] * ndhwc[4]), - static_cast(ndhwc[4])};**/ std::array a_strides = { static_cast(ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]), @@ -93,6 +56,20 @@ int main() 1}; ck::ranges::copy(ncdhw, ab_lengths.begin()); + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); + + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + auto broadcastPermute = DeviceElementwisePermuteInstance{}; auto argument = broadcastPermute.MakeArgumentPointer( ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); @@ -126,10 +103,16 @@ int main() if(do_verification) { - 
b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(ndhwc); - host_elementwise4D(host_b, a, PassThrough{}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = + ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_3d.cpp b/example/44_elementwise_permute/elementwise_permute_3d.cpp index f3aca57c35..47d8c4de65 100644 --- a/example/44_elementwise_permute/elementwise_permute_3d.cpp +++ b/example/44_elementwise_permute/elementwise_permute_3d.cpp @@ -8,6 +8,8 @@ #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -34,20 +36,6 @@ using DeviceElementwisePermuteInstance = ck::Sequence<4>, // InScalarPerVectorSeq ck::Sequence<4>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_ndhwc, const HostTensorA& A_ncdhw, Functor functor) -{ - for(std::size_t n = 0; n < A_ncdhw.mDesc.GetLengths()[0]; ++n) - for(std::size_t c = 0; c < A_ncdhw.mDesc.GetLengths()[1]; ++c) - for(std::size_t d = 0; d < A_ncdhw.mDesc.GetLengths()[2]; ++d) - for(std::size_t h = 0; h < A_ncdhw.mDesc.GetLengths()[3]; ++h) - for(std::size_t w = 0; w < A_ncdhw.mDesc.GetLengths()[4]; ++w) - { - auto a_val = A_ncdhw(n, c, d, h, w); - functor(B_ndhwc(n, d, h, w, c), a_val); - } -} - 
int main() { bool do_verification = true; @@ -59,10 +47,13 @@ int main() const int W = 5; const int D = 16; - std::vector ncdhw = {N, C, D, H, W}; - std::vector ndhwc = {N, D, H, W, C}; - Tensor a(ncdhw); - Tensor b(ndhwc); + std::array ab_lengths{N, C, H, W, D}; + std::array a_strides = {C * D * H * W, H * W, W, 1, D * H * W}; // N, C, D, H, W + std::array b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, D, H, W, C + + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -74,10 +65,6 @@ int main() std::array input = {a_device_buf.GetDeviceBuffer()}; std::array output = {b_device_buf.GetDeviceBuffer()}; - std::array ab_lengths{N, C, H, W, D}; - std::array a_strides = {C * D * H * W, H * W, W, 1, D * H * W}; // N, C, D, H, W - std::array b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, D, H, W, C - auto broadcastPermute = DeviceElementwisePermuteInstance{}; auto argument = broadcastPermute.MakeArgumentPointer( ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); @@ -94,11 +81,12 @@ int main() auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); float ave_time = broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); - std::size_t flop = std::size_t(2) * ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]; + std::size_t flop = std::size_t(2) * ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * + ab_lengths[3] * ab_lengths[4]; std::size_t num_btype = - sizeof(ADataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]) + - sizeof(BDataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]); + (sizeof(ADataType) + sizeof(BDataType)) * + (ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * ab_lengths[3] * ab_lengths[4]); float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -111,10 +99,17 @@ int main() if(do_verification) { - 
b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(ndhwc); - host_elementwise4D(host_b, a, PassThrough{}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = + ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp index 1b28a901cb..3ea1aa4bf8 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp @@ -8,6 +8,8 @@ #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -35,19 +37,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<8>, // InScalarPerVectorSeq ck::Sequence<8>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) -{ - for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n) - for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c) - for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h) - for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w) - { - auto a_val = A_nchw(n, c, h, w); - functor(B_nhwc(n, h, w, c), a_val); - } -} - 
int main() { bool do_verification = true; @@ -55,18 +44,6 @@ int main() std::vector nchw = {16, 128, 32, 64}; std::vector nhwc = {16, 32, 64, 128}; - Tensor a(nchw); - Tensor b(nhwc); - - a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a.mData.data()); - - std::array input = {a_device_buf.GetDeviceBuffer()}; - std::array output = {b_device_buf.GetDeviceBuffer()}; std::array ab_lengths; std::array a_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), @@ -77,9 +54,22 @@ int main() 1, static_cast(nhwc[2] * nhwc[3]), static_cast(nhwc[3])}; - ck::ranges::copy(nchw, ab_lengths.begin()); + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); + + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + auto broadcastPermute = DeviceElementwisePermuteInstance{}; auto argument = broadcastPermute.MakeArgumentPointer( ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); @@ -111,10 +101,16 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(nhwc); - host_elementwise4D(host_b, a, PassThrough{}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = + ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{}); + ref_invoker.Run(ref_argument); + 
+ b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp index 30231a3758..1747e6dd8b 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp @@ -8,6 +8,8 @@ #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor.hpp" @@ -30,22 +32,6 @@ using DeviceElementwisePermuteInstance = ck::Sequence<1>, // InScalarPerVectorSeq ck::Sequence<1>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, - const HostTensorA& A_nchw, - const std::vector& shape_nchw, - Functor functor) -{ - for(std::size_t n = 0; n < shape_nchw[0]; ++n) - for(std::size_t c = 0; c < shape_nchw[1]; ++c) - for(std::size_t h = 0; h < shape_nchw[2]; ++h) - for(std::size_t w = 0; w < shape_nchw[3]; ++w) - { - auto a_val = A_nchw(n, c, h, w); - functor(B_nhwc(n, h, w, c), a_val); - } -} - int main() { bool do_verification = true; @@ -54,13 +40,16 @@ int main() const int N = 120; const int C = 128; const int H = 32; - const int W = 1024; + const int W = 32; - std::vector nchw = {N, C, H, W}; - std::vector nhwc = {N, H, W, C}; + std::array ab_lengths{N, H, W, C}; - Tensor a(nchw); - Tensor b(nhwc); + std::array a_strides = {C * H * W, W, 1, H * W}; + std::array b_strides = {H * W * C, W * C, C, 1}; + + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); a.GenerateTensorValue(GeneratorTensor_3{0.0, 
1.0}); @@ -72,11 +61,6 @@ int main() std::array input = {a_device_buf.GetDeviceBuffer()}; std::array output = {b_device_buf.GetDeviceBuffer()}; - std::array ab_lengths{N, H, W, C}; - - std::array a_strides = {C * H * W, W, 1, H * W}; - std::array b_strides = {H * W * C, W * C, C, 1}; - auto broadcastPermute = DeviceElementwisePermuteInstance{}; auto argument = broadcastPermute.MakeArgumentPointer( ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); @@ -94,10 +78,11 @@ int main() float ave_time = broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); - std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + std::size_t flop = + std::size_t(2) * ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * ab_lengths[3]; - std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + - sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + std::size_t num_btype = (sizeof(ADataType) + sizeof(BDataType)) * + (ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * ab_lengths[3]); float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -110,11 +95,16 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = + ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); - Tensor host_b(nhwc); - host_elementwise4D, Tensor, PassThrough>( - host_b, a, nchw, PassThrough{}); + auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp 
b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp index f832601f07..13c67fce05 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp @@ -6,9 +6,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -21,11 +23,14 @@ using F32 = float; using ADataType = F16; using BDataType = F16; -using UnaryOp = ck::tensor_operation::element_wise::Scale; +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< ck::Tuple, // InDataTypeTuple ck::Tuple, // OutDataTypeTuple - UnaryOp, // UnaryOp + UnaryScaleSquare, // UnaryScaleSquare 4, // NumDim 256, // BlockSize 128, // M0PerBlock @@ -36,23 +41,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<8>, // InScalarPerVectorSeq ck::Sequence<8>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) -{ - std::size_t N = A_nchw.mDesc.GetLengths()[0]; - std::size_t C = A_nchw.mDesc.GetLengths()[1]; - std::size_t H = A_nchw.mDesc.GetLengths()[2]; - std::size_t W = A_nchw.mDesc.GetLengths()[3]; - for(std::size_t w = 0; w < W; ++w) - for(std::size_t h = 0; h < H; ++h) - for(std::size_t c = 0; c < C; ++c) - 
for(std::size_t n = 0; n < N; ++n) - { - auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)]; - functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val); - } -} - int main() { bool do_verification = true; @@ -60,8 +48,21 @@ int main() std::vector nchw = {16, 8, 32, 64}; std::vector nhwc = {16, 32, 64, 8}; - Tensor a(nchw); - Tensor b(nhwc); + std::array ab_lengths; + std::array a_strides = {1, + static_cast(nchw[0]), + static_cast(nchw[0] * nchw[1]), + static_cast(nchw[0] * nchw[1] * nchw[2])}; + + std::array b_strides = {1, + static_cast(nhwc[0] * nhwc[1] * nhwc[2]), + static_cast(nhwc[0]), + static_cast(nhwc[0] * nhwc[1])}; + ck::ranges::copy(nchw, ab_lengths.begin()); + + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); float scale = 1.f; auto i = 0; std::mt19937 gen(11939); @@ -84,22 +85,14 @@ int main() std::array input = {a_device_buf.GetDeviceBuffer()}; std::array output = {b_device_buf.GetDeviceBuffer()}; - std::array ab_lengths; - - std::array a_strides = {1, - static_cast(nchw[0]), - static_cast(nchw[0] * nchw[1]), - static_cast(nchw[0] * nchw[1] * nchw[2])}; - - std::array b_strides = {1, - static_cast(nhwc[0] * nhwc[1] * nhwc[2]), - static_cast(nhwc[0]), - static_cast(nhwc[0] * nhwc[1])}; - ck::ranges::copy(nchw, ab_lengths.begin()); - auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer( - ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale}); + auto argument = + broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -113,11 +106,10 @@ int main() auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); float ave_time = broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, 
time_kernel}); - std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; - - std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + - sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + std::size_t num_btype = + (2 * sizeof(ADataType) + sizeof(BDataType)) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; @@ -129,10 +121,17 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(nhwc); - host_elementwise4D(host_b, a, UnaryOp{scale}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + auto ref_argument = ref_elementwise.MakeArgument( + as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp index bae85f53c1..0a0f6fec10 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp @@ -5,9 +5,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + 
#include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -20,11 +22,14 @@ using F32 = float; using ADataType = F16; using BDataType = F16; -using UnaryOp = ck::tensor_operation::element_wise::Scale; +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< ck::Tuple, // InDataTypeTuple ck::Tuple, // OutDataTypeTuple - UnaryOp, // UnaryOp + UnaryScaleSquare, // UnaryScaleSquare 4, // NumDim 256, // BlockSize 128, // M0PerBlock @@ -35,19 +40,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<8>, // InScalarPerVectorSeq ck::Sequence<8>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) -{ - for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n) - for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c) - for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h) - for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w) - { - auto a_val = A_nchw(n, c, h, w); - functor(B_nhwc(n, h, w, c), a_val); - } -} - int main() { bool do_verification = true; @@ -55,18 +47,6 @@ int main() std::vector nchw = {16, 128, 32, 64}; std::vector nhwc = {16, 32, 64, 128}; - Tensor a(nchw); - Tensor b(nhwc); - float scale = 2.f; - a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a.mData.data()); - - std::array input = {a_device_buf.GetDeviceBuffer()}; - std::array output = {b_device_buf.GetDeviceBuffer()}; std::array ab_lengths; std::array 
a_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), @@ -80,9 +60,29 @@ int main() ck::ranges::copy(nchw, ab_lengths.begin()); + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); + + float scale = 2.f; + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer( - ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale}); + auto argument = + broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -112,10 +112,17 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(nhwc); - host_elementwise4D(host_b, a, UnaryOp{scale}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + auto ref_argument = ref_elementwise.MakeArgument( + as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp index fe7acd3010..fc664186be 100644 --- 
a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp @@ -5,9 +5,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -20,11 +22,14 @@ using F32 = float; using ADataType = F32; using BDataType = F32; -using UnaryOp = ck::tensor_operation::element_wise::Scale; +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< ck::Tuple, // InDataTypeTuple ck::Tuple, // OutDataTypeTuple - UnaryOp, // UnaryOp + UnaryScaleSquare, // UnaryScaleSquare 4, // NumDim 256, // BlockSize 128, // M0PerBlock @@ -35,32 +40,29 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<1>, // InScalarPerVectorSeq ck::Sequence<1>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) -{ - std::size_t N = A_nchw.mDesc.GetLengths()[0]; - std::size_t C = A_nchw.mDesc.GetLengths()[1]; - std::size_t H = A_nchw.mDesc.GetLengths()[2]; - std::size_t W = A_nchw.mDesc.GetLengths()[3]; - for(std::size_t w = 0; w < W; ++w) - for(std::size_t h = 0; h < H; ++h) - for(std::size_t c = 0; c < C; ++c) - for(std::size_t n = 0; n < N; ++n) - { - auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)]; - 
functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val); - } -} - int main() { bool do_verification = true; bool time_kernel = true; - std::vector nchw = {5, 4, 2, 3}; - std::vector nhwc = {5, 2, 3, 4}; - Tensor a(nchw); - Tensor b(nhwc); + std::vector nchw = {16, 8, 32, 64}; + std::vector nhwc = {16, 32, 64, 8}; + std::array ab_lengths; + + std::array a_strides = {1, + static_cast(nchw[0]), + static_cast(nchw[0] * nchw[1]), + static_cast(nchw[0] * nchw[1] * nchw[2])}; + + std::array b_strides = {1, + static_cast(nhwc[0] * nhwc[1] * nhwc[2]), + static_cast(nhwc[0]), + static_cast(nhwc[0] * nhwc[1])}; + ck::ranges::copy(nchw, ab_lengths.begin()); + + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); float scale = 1.f; auto i = 0; @@ -84,22 +86,14 @@ int main() std::array input = {a_device_buf.GetDeviceBuffer()}; std::array output = {b_device_buf.GetDeviceBuffer()}; - std::array ab_lengths; - - std::array a_strides = {1, - static_cast(nchw[0]), - static_cast(nchw[0] * nchw[1]), - static_cast(nchw[0] * nchw[1] * nchw[2])}; - - std::array b_strides = {1, - static_cast(nhwc[0] * nhwc[1] * nhwc[2]), - static_cast(nhwc[0]), - static_cast(nhwc[0] * nhwc[1])}; - ck::ranges::copy(nchw, ab_lengths.begin()); - auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer( - ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale}); + auto argument = + broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -129,10 +123,17 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(nhwc); - host_elementwise4D(host_b, a, UnaryOp{scale}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + 
ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + auto ref_argument = ref_elementwise.MakeArgument( + as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp index aebdb37d9b..a0c416318a 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp @@ -5,9 +5,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -20,11 +22,14 @@ using F32 = float; using ADataType = F32; using BDataType = F32; -using UnaryOp = ck::tensor_operation::element_wise::Scale; +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< ck::Tuple, // InDataTypeTuple ck::Tuple, // OutDataTypeTuple - UnaryOp, // UnaryOp + UnaryScaleSquare, // UnaryScaleSquare 4, // NumDim 256, // BlockSize 128, // M0PerBlock @@ -35,19 +40,6 @@ using 
DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<8>, // InScalarPerVectorSeq ck::Sequence<8>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) -{ - for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n) - for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c) - for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h) - for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w) - { - auto a_val = A_nchw(n, c, h, w); - functor(B_nhwc(n, h, w, c), a_val); - } -} - int main() { bool do_verification = true; @@ -55,18 +47,6 @@ int main() std::vector nchw = {16, 128, 32, 64}; std::vector nhwc = {16, 32, 64, 128}; - Tensor a(nchw); - Tensor b(nhwc); - float scale = 2.f; - a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a.mData.data()); - - std::array input = {a_device_buf.GetDeviceBuffer()}; - std::array output = {b_device_buf.GetDeviceBuffer()}; std::array ab_lengths; std::array a_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), @@ -80,9 +60,28 @@ int main() ck::ranges::copy(nchw, ab_lengths.begin()); + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); + float scale = 2.f; + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer( - ab_lengths, {a_strides}, {b_strides}, 
input, output, UnaryOp{scale}); + auto argument = + broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -112,10 +111,17 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(nhwc); - host_elementwise4D(host_b, a, UnaryOp{scale}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + auto ref_argument = ref_elementwise.MakeArgument( + as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp new file mode 100644 index 0000000000..050300eed2 --- /dev/null +++ b/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; + +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; +using BinaryAdd = ck::tensor_operation::element_wise::Add; +// B = alpha * A0 * A0 + beta * A1 * A1 + gamma * A2 * A2 +using TrinaryAddUnaryScaleSquare = + ck::tensor_operation::element_wise::TrinaryWithUnaryCombinedOp; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + TrinaryAddUnaryScaleSquare, // ElementwiseOp + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<8, 8, 8>, // InScalarPerVectorSeq + ck::Sequence<8>>; // OutScalarPerVectorSeq + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + + std::vector nchw = {16, 128, 32, 64}; + std::array ab_lengths; + std::array ab_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), + static_cast(nchw[2] * nchw[3]), + static_cast(nchw[3]), + 1}; + + ck::ranges::copy(nchw, ab_lengths.begin()); + + std::array, 3> as = {Tensor(ab_lengths, ab_strides), + Tensor(ab_lengths, ab_strides), + Tensor(ab_lengths, 
ab_strides)}; + Tensor& a0 = as[0]; + Tensor& a1 = as[1]; + Tensor& a2 = as[2]; + Tensor b(ab_lengths, ab_strides); + float alpha = 3.f; + float beta = 2.f; + float gamma = 4.f; + a0.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a2.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a0_device_buf(sizeof(ADataType) * a0.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(ADataType) * a1.mDesc.GetElementSpaceSize()); + DeviceMem a2_device_buf(sizeof(ADataType) * a2.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0.mData.data()); + a1_device_buf.ToDevice(a1.mData.data()); + a2_device_buf.ToDevice(a2.mData.data()); + + std::array inputs = {a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + a2_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + + auto broadcastPermute = DeviceElementwisePermuteInstance{}; + auto unary_scale_op_a0 = UnaryScaleSquare{UnarySquare{}, UnaryScale{alpha}}; + auto unary_scale_op_a1 = UnaryScaleSquare{UnarySquare{}, UnaryScale{beta}}; + auto unary_scale_op_a2 = UnaryScaleSquare{UnarySquare{}, UnaryScale{gamma}}; + auto argument = broadcastPermute.MakeArgumentPointer( + ab_lengths, + {ab_strides, ab_strides, ab_strides}, + {ab_strides}, + inputs, + output, + TrinaryAddUnaryScaleSquare{ + BinaryAdd{}, BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1, unary_scale_op_a2}); + + if(!broadcastPermute.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + std::cout << "A0 (nchw): " << a0.mDesc << std::endl; + std::cout << "A1 (nchw): " << a1.mDesc << std::endl; + std::cout << "A2 (nchw): " << a2.mDesc << std::endl; + std::cout << "B (nchw): " << b.mDesc << std::endl; + + auto broadcastPermute_invoker_ptr = 
broadcastPermute.MakeInvokerPointer(); + float ave_time = + broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + + std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + + sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + Tensor host_b(ab_lengths, ab_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<3, ADataType, BDataType, TrinaryAddUnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument( + as, + host_b, + TrinaryAddUnaryScaleSquare{ + BinaryAdd{}, BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1, unary_scale_op_a2}); + ref_invoker.Run(ref_argument); + + const double threshold = std::pow(2, -10) * 2; + b_device_buf.FromDevice(b.mData.data()); + pass &= ck::utils::check_err( + b.mData, host_b.mData, "Error: Incorrect results b", threshold, threshold); + } + + return pass ? 
0 : 1; +} diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index ba2e0057d9..f6e57aad09 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -92,6 +92,110 @@ struct Add }; }; +struct Max +{ + template + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const + { + const Y x0_converted = type_convert(x0); + const Y x1_converted = type_convert(x1); + y = ck::math::max(x0_converted, x1_converted); + } +}; + +struct Min +{ + template + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const + { + const Y x0_converted = type_convert(x0); + const Y x1_converted = type_convert(x1); + y = ck::math::min(x0_converted, x1_converted); + } +}; + +struct Multiply +{ + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = x0 * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = x0 * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const half_t& x1) const + { + y = x0 * type_convert(x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const float& x1) const + { + y = type_convert(x0 * x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1) const + { + y = type_convert(x0) * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + y = x0 * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, 
const float& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x1); + y = x0 * x1_tmp; + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x0); + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x1_tmp * x2_tmp; + y = ck::type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const float& x0, const bhalf_t& x1) const + { + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x0 * x2_tmp; + y = ck::type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const + { + y = x0 * x1; + }; +}; + struct ScaleAdd { __host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {} diff --git a/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp new file mode 100644 index 0000000000..6d1d6b57c5 --- /dev/null +++ b/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +// y = UnaryOp0(UnaryOp1(...(x))) +template +struct UnaryCombinedOp +{ + __host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops) : unary_ops_(unary_ops...) 
{} + + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + // Execute first unary op to copy data to y + unary_ops_.At(Number<0>{})(y, x); + + static_for<1, Tuple::Size(), 1>{}([&](auto i) { unary_ops_.At(i)(y, y); }); + }; + + Tuple unary_ops_; +}; + +// y = BinaryOp(UnaryOp0(x0), UnaryOp1(x1)) +template +struct BinaryWithUnaryCombinedOp +{ + __host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op, + UnaryOp0 unary_op0, + UnaryOp1 unary_op1) + : binary_op_(binary_op), unary_op0_(unary_op0), unary_op1_(unary_op1) + { + } + + template + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const + { + Y unary_x0_tmp_result; + Y unary_x1_tmp_result; + unary_op0_(unary_x0_tmp_result, x0); + unary_op1_(unary_x1_tmp_result, x1); + binary_op_(y, unary_x0_tmp_result, unary_x1_tmp_result); + }; + + private: + BinaryOp binary_op_; + UnaryOp0 unary_op0_; + UnaryOp1 unary_op1_; +}; + +// y = BinaryOp0(BinaryOp1(UnaryOp0(x0), UnaryOp1(x1)), UnaryOp2(x2)) +template +struct TrinaryWithUnaryCombinedOp +{ + __host__ __device__ TrinaryWithUnaryCombinedOp(BinaryOp0 binary_op0, + BinaryOp0 binary_op1, + UnaryOp0 unary_op0, + UnaryOp1 unary_op1, + UnaryOp2 unary_op2) + : binary_op0_(binary_op0), + binary_op1_(binary_op1), + unary_op0_(unary_op0), + unary_op1_(unary_op1), + unary_op2_(unary_op2) + { + } + + template + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1, const X2& x2) const + { + + Y unary_x0_tmp_result; + Y unary_x1_tmp_result; + Y unary_x2_tmp_result; + unary_op0_(unary_x0_tmp_result, x0); + unary_op1_(unary_x1_tmp_result, x1); + unary_op2_(unary_x2_tmp_result, x2); + binary_op0_(unary_x0_tmp_result, unary_x0_tmp_result, unary_x1_tmp_result); + binary_op1_(y, unary_x0_tmp_result, unary_x2_tmp_result); + }; + + private: + BinaryOp0 binary_op0_{}; + BinaryOp1 binary_op1_{}; + UnaryOp0 unary_op0_{}; + UnaryOp1 unary_op1_{}; + UnaryOp2 unary_op2_{}; +}; + +} // namespace element_wise +} // namespace 
tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 9c64ad4dfa..1add81e69e 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -12,10 +12,6 @@ namespace ck { namespace tensor_operation { namespace element_wise { -#if CK_WORKAROUND_SWDEV_383542 -extern "C" __device__ float __ocml_native_recip_f32(float); -#endif - struct PassThroughPack2 { template @@ -449,11 +445,7 @@ struct FastGelu const float u = x * (c1 * x * x + c2); const float emu = __expf(u); -#if !CK_WORKAROUND_SWDEV_383542 - y = x * __frcp_rn(1.f + emu); -#else - y = x * __ocml_native_recip_f32(1.f + emu); -#endif + y = x * ck::math::rcp(1.f + emu); } template <> @@ -559,6 +551,244 @@ struct TanH }; }; +struct ACos +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::acos(x); + }; +}; + +struct Neg +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::neg(x); + }; +}; + +struct ATan +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::atan(x); + }; +}; + +struct Sin +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + 
is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::sin(x); + }; +}; + +struct ASinH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::asinh(x); + }; +}; + +struct Cos +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::cos(x); + }; +}; + +struct ACosH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::acosh(x); + }; +}; + +struct Tan +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::tan(x); + }; +}; + +struct ATanH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::atanh(x); + }; +}; + +struct SinH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::sinh(x); + }; +}; + +struct Ceil +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + 
is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::ceil(x); + }; +}; + +struct Exp +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::exp(x); + }; +}; + +struct CosH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::cosh(x); + }; +}; + +struct Floor +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::floor(x); + }; +}; + +struct Log +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::log(x); + }; +}; + +struct ASin +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::asin(x); + }; +}; + +struct Rcp +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::rcp(x); + }; +}; + struct Swish { Swish(float beta = 1.0f) : beta_(beta) {} diff --git 
a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp index 2a906a1432..4d1a09b445 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp @@ -118,8 +118,16 @@ struct GridwiseElementwise __builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock); const index_t m1_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * M1PerBlock); - const auto thread_grid_offset = - make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid); + const auto input_thread_grid_offset = generate_tuple( + [&](auto) { + return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid); + }, + Number{}); + const auto output_thread_grid_offset = generate_tuple( + [&](auto) { + return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid); + }, + Number{}); using ThisThreadBlock = ThisThreadBlock; // If src and dst have same vector dim, then: @@ -157,9 +165,9 @@ struct GridwiseElementwise uniform_sequence_gen_t, uniform_sequence_gen_t, uniform_sequence_gen_t>{in_grid_desc_tuple, - thread_grid_offset, + input_thread_grid_offset, out_grid_desc_tuple, - thread_grid_offset, + output_thread_grid_offset, elementwise_op}; global_to_global_transfer.Run( in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple, I0); diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp index a07fde3da3..2b921cdc7c 100644 --- a/include/ck/utility/math_v2.hpp +++ b/include/ck/utility/math_v2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -14,6 +14,10 @@ namespace ck { namespace math { +#if CK_WORKAROUND_SWDEV_383542 +extern "C" __device__ float __ocml_native_recip_f32(float); +#endif + // math functions for the host, some are implemented by calling C++ std functions static inline __host__ float abs(float x) { return std::abs(x); }; @@ -111,6 +115,276 @@ inline __host__ double tanh(double x) return std::tanh(x); }; +template +inline __host__ T acos(T x) +{ + return ck::type_convert(std::acosf(ck::type_convert(x))); +}; + +template <> +inline __host__ float acos(float x) +{ + return std::acosf(x); +}; + +template <> +inline __host__ double acos(double x) +{ + return std::acos(x); +}; + +template +inline __host__ T neg(T x) +{ + return ck::type_convert(-(ck::type_convert(x))); +}; + +template <> +inline __host__ float neg(float x) +{ + return -x; +}; + +template <> +inline __host__ double neg(double x) +{ + return -x; +}; + +template <> +inline __host__ int32_t neg(int32_t x) +{ + return -x; +}; + +template <> +inline __host__ int8_t neg(int8_t x) +{ + return -x; +}; + +template +inline __host__ T atan(T x) +{ + return ck::type_convert(std::atanf(ck::type_convert(x))); +}; + +template <> +inline __host__ float atan(float x) +{ + return std::atanf(x); +}; + +template <> +inline __host__ double atan(double x) +{ + return std::atan(x); +}; + +template +inline __host__ T sin(T x) +{ + return ck::type_convert(std::sinf(ck::type_convert(x))); +}; + +template <> +inline __host__ float sin(float x) +{ + return std::sinf(x); +}; + +template <> +inline __host__ double sin(double x) +{ + return std::sin(x); +}; + +template +inline __host__ T asin(T x) +{ + return ck::type_convert(std::asinf(ck::type_convert(x))); +}; + +template <> +inline __host__ float asin(float x) +{ + return std::asinf(x); +}; + +template <> +inline __host__ double asin(double x) +{ + return std::asin(x); +}; + +template +inline __host__ T asinh(T x) +{ + return ck::type_convert(std::asinhf(ck::type_convert(x))); +}; + 
+template <> +inline __host__ float asinh(float x) +{ + return std::asinhf(x); +}; + +template <> +inline __host__ double asinh(double x) +{ + return std::asinh(x); +}; + +template +inline __host__ T cos(T x) +{ + return ck::type_convert(std::cosf(ck::type_convert(x))); +}; + +template <> +inline __host__ float cos(float x) +{ + return std::cosf(x); +}; + +template <> +inline __host__ double cos(double x) +{ + return std::cos(x); +}; + +template +inline __host__ T acosh(T x) +{ + return ck::type_convert(std::acoshf(ck::type_convert(x))); +}; + +template <> +inline __host__ float acosh(float x) +{ + return std::acoshf(x); +}; + +template <> +inline __host__ double acosh(double x) +{ + return std::acosh(x); +}; + +template +inline __host__ T tan(T x) +{ + return ck::type_convert(std::tanf(ck::type_convert(x))); +}; + +template <> +inline __host__ float tan(float x) +{ + return std::tanf(x); +}; + +template <> +inline __host__ double tan(double x) +{ + return std::tan(x); +}; + +template +inline __host__ T atanh(T x) +{ + return ck::type_convert(std::atanhf(ck::type_convert(x))); +}; + +template <> +inline __host__ float atanh(float x) +{ + return std::atanhf(x); +}; + +template <> +inline __host__ double atanh(double x) +{ + return std::atanh(x); +}; + +template +inline __host__ T sinh(T x) +{ + return ck::type_convert(std::sinhf(ck::type_convert(x))); +}; + +template <> +inline __host__ float sinh(float x) +{ + return std::sinhf(x); +}; + +template <> +inline __host__ double sinh(double x) +{ + return std::sinh(x); +}; + +template +inline __host__ T ceil(T x) +{ + return ck::type_convert(std::ceilf(ck::type_convert(x))); +}; + +template <> +inline __host__ float ceil(float x) +{ + return std::ceilf(x); +}; + +template <> +inline __host__ double ceil(double x) +{ + return std::ceil(x); +}; + +template +inline __host__ T cosh(T x) +{ + return ck::type_convert(std::coshf(ck::type_convert(x))); +}; + +template <> +inline __host__ float cosh(float x) +{ + return 
std::coshf(x); +}; + +template <> +inline __host__ double cosh(double x) +{ + return std::cosh(x); +}; + +template +inline __host__ T floor(T x) +{ + return ck::type_convert(std::floorf(ck::type_convert(x))); +}; + +template <> +inline __host__ float floor(float x) +{ + return std::floorf(x); +}; + +template <> +inline __host__ double floor(double x) +{ + return std::floor(x); +}; + +template +inline __host__ T rcp(T x) +{ + return ck::type_convert(1.f / ck::type_convert(x)); +}; + template inline __host__ T exp(T x) { @@ -282,6 +556,286 @@ inline __device__ double tanh(double x) return ::tanh(x); }; +template +inline __device__ T acos(T x) +{ + return ck::type_convert(::acosf(ck::type_convert(x))); +}; + +template <> +inline __device__ float acos(float x) +{ + return ::acosf(x); +}; + +template <> +inline __device__ double acos(double x) +{ + return ::acos(x); +}; + +template +inline __device__ T neg(T x) +{ + return ck::type_convert(-(ck::type_convert(x))); +}; + +template <> +inline __device__ float neg(float x) +{ + return -x; +}; + +template <> +inline __device__ double neg(double x) +{ + return -x; +}; + +template <> +inline __device__ int32_t neg(int32_t x) +{ + return -x; +}; + +template <> +inline __device__ int8_t neg(int8_t x) +{ + return -x; +}; + +template <> +inline __device__ half_t neg(half_t x) +{ + return __hneg(x); +}; + +template +inline __device__ T atan(T x) +{ + return ck::type_convert(::atanf(ck::type_convert(x))); +}; + +template <> +inline __device__ float atan(float x) +{ + return ::atanf(x); +}; + +template <> +inline __device__ double atan(double x) +{ + return ::atan(x); +}; + +template +inline __device__ T sin(T x) +{ + return ck::type_convert(::sinf(ck::type_convert(x))); +}; + +template <> +inline __device__ float sin(float x) +{ + return ::sinf(x); +}; + +template <> +inline __device__ double sin(double x) +{ + return ::sin(x); +}; + +template <> +inline __device__ half_t sin(half_t x) +{ + return ::hsin(x); +}; + +template +inline 
__device__ T asin(T x) +{ + return ck::type_convert(::asinf(ck::type_convert(x))); +}; + +template <> +inline __device__ float asin(float x) +{ + return ::asinf(x); +}; + +template <> +inline __device__ double asin(double x) +{ + return ::asin(x); +}; + +template +inline __device__ T asinh(T x) +{ + return ck::type_convert(::asinhf(ck::type_convert(x))); +}; + +template <> +inline __device__ float asinh(float x) +{ + return ::asinhf(x); +}; + +template <> +inline __device__ double asinh(double x) +{ + return ::asinh(x); +}; + +template +inline __device__ T acosh(T x) +{ + return ck::type_convert(::acoshf(ck::type_convert(x))); +}; + +template <> +inline __device__ float acosh(float x) +{ + return ::acoshf(x); +}; + +template <> +inline __device__ double acosh(double x) +{ + return ::acosh(x); +}; + +template +inline __device__ T tan(T x) +{ + return ck::type_convert(::tanf(ck::type_convert(x))); +}; + +template <> +inline __device__ float tan(float x) +{ + return ::tanf(x); +}; + +template <> +inline __device__ double tan(double x) +{ + return ::tan(x); +}; + +template +inline __device__ T atanh(T x) +{ + return ck::type_convert(::atanhf(ck::type_convert(x))); +}; + +template <> +inline __device__ float atanh(float x) +{ + return ::atanhf(x); +}; + +template <> +inline __device__ double atanh(double x) +{ + return ::atanh(x); +}; + +template +inline __device__ T sinh(T x) +{ + return ck::type_convert(::sinhf(ck::type_convert(x))); +}; + +template <> +inline __device__ float sinh(float x) +{ + return ::sinhf(x); +}; + +template <> +inline __device__ double sinh(double x) +{ + return ::sinh(x); +}; + +template +inline __device__ T ceil(T x) +{ + return ck::type_convert(::ceilf(ck::type_convert(x))); +}; + +template <> +inline __device__ float ceil(float x) +{ + return ::ceilf(x); +}; + +template <> +inline __device__ double ceil(double x) +{ + return ::ceil(x); +}; + +template <> +inline __device__ half_t ceil(half_t x) +{ + return ::hceil(x); +}; + +template +inline 
__device__ T cosh(T x) +{ + return ck::type_convert(::coshf(ck::type_convert(x))); +}; + +template <> +inline __device__ float cosh(float x) +{ + return ::coshf(x); +}; + +template <> +inline __device__ double cosh(double x) +{ + return ::cosh(x); +}; + +template +inline __device__ T floor(T x) +{ + return ck::type_convert(::floorf(ck::type_convert(x))); +}; + +template <> +inline __device__ float floor(float x) +{ + return ::floorf(x); +}; + +template <> +inline __device__ double floor(double x) +{ + return ::floor(x); +}; + +template <> +inline __device__ half_t floor(half_t x) +{ + return ::hfloor(x); +}; + +template +inline __device__ T rcp(T x) +{ +#if !CK_WORKAROUND_SWDEV_383542 + return __frcp_rn(x); +#else + return __ocml_native_recip_f32(x); +#endif +}; + template inline __device__ T exp(T x) { diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp new file mode 100644 index 0000000000..470641fff7 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceElementwise : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const std::array, NumATensors>& a_tensors, + Tensor& b_tensor, + ElementOp element_op) + : a_tensors_{a_tensors}, b_tensor_{b_tensor}, element_op_{element_op} + { + } + + const std::array, NumATensors>& a_tensors_; + Tensor& b_tensor_; + ElementOp element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceElementwise::Argument; + + float Run(const Argument& arg) + { + if constexpr(NumATensors == 1) + { + arg.b_tensor_.ForEach([&](auto& self, auto idx) { + arg.element_op_(self(idx), arg.a_tensors_[0](idx)); + }); + } + else if constexpr(NumATensors == 2) + { + arg.b_tensor_.ForEach([&](auto& self, auto idx) { + arg.element_op_(self(idx), arg.a_tensors_[0](idx), arg.a_tensors_[1](idx)); + }); + } + else if constexpr(NumATensors == 3) + { + arg.b_tensor_.ForEach([&](auto& self, auto idx) { + arg.element_op_(self(idx), + arg.a_tensors_[0](idx), + arg.a_tensors_[1](idx), + arg.a_tensors_[2](idx)); + }); + } + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const std::array, NumATensors>& a_tensors, + Tensor& b_tensor, + ElementOp element_op) + { + return Argument{a_tensors, b_tensor, element_op}; + } + + static auto 
MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceElementwise" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_permute_scale_impl.hpp b/profiler/include/profiler/profile_permute_scale_impl.hpp index c69e36142d..186a24501e 100644 --- a/profiler/include/profiler/profile_permute_scale_impl.hpp +++ b/profiler/include/profiler/profile_permute_scale_impl.hpp @@ -14,6 +14,8 @@ #include "ck/library/tensor_operation_instance/gpu/permute_scale.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor.hpp" @@ -21,14 +23,6 @@ #include "ck/library/utility/literals.hpp" namespace ck { -template -void reference_permute_scale(HostTensorB& b_tensor, - const HostTensorA& a_tensor, - ElementOp tensor_op) -{ - b_tensor.ForEach([&](auto& self, auto idx) { tensor_op(self(idx), a_tensor(idx)); }); -} - namespace profiler { template @@ -46,7 +40,8 @@ bool profile_permute_scale_impl(int do_verification, using ElementOp = ck::tensor_operation::element_wise::Scale; float scale = 2.f; - Tensor a(lengths_vector, input_strides_vector); + std::array, 1> as = {Tensor(lengths_vector, input_strides_vector)}; + Tensor& a = as[0]; Tensor b(lengths_vector, output_strides_vector); Tensor host_b(lengths_vector, output_strides_vector); @@ -83,7 +78,14 @@ bool profile_permute_scale_impl(int do_verification, if(do_verification) { - reference_permute_scale(host_b, a, ElementOp{scale}); + using ReferenceElementwiseInstance = + ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, 
ElementOp>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument(as, host_b, ElementOp{scale}); + + ref_invoker.Run(ref_argument); } auto copy = [](const auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); }; From a61e73bc56966a138ab1b5dadf27983800788431 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Wed, 3 Apr 2024 09:08:08 -0500 Subject: [PATCH 4/7] Add instances for conv_scale with fp8@bf8->fp8 (#1220) * Update device op api to support BComputeType * Add example * Add instances * Add profiler mode * Add client example * Update copyright year * Add BComputeType check * Fix compute types --- client_example/16_convnd_fwd/CMakeLists.txt | 5 + client_example/16_convnd_fwd/common.hpp | 8 +- .../16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp | 50 ++++++++ example/09_convnd_fwd/CMakeLists.txt | 1 + .../09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp | 83 ++++++++++++ .../device_grouped_conv_fwd_multiple_abd.hpp | 10 +- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 15 ++- ...ouped_conv_fwd_multiple_d_xdl_cshuffle.hpp | 10 +- ...ridwise_gemm_multiple_abd_xdl_cshuffle.hpp | 39 +++--- .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 7 +- .../device_grouped_conv_fwd_xdl_instance.hpp | 36 ++++++ .../gpu/grouped_convolution_forward.hpp | 118 ++++++++++++------ .../gpu/grouped_convolution_forward_xdl.inc | 18 +++ .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 5 + ..._ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp | 54 ++++++++ .../profile_grouped_conv_fwd_impl.hpp | 10 +- profiler/src/profile_grouped_conv_fwd.cpp | 75 ++++++----- 17 files changed, 441 insertions(+), 103 deletions(-) create mode 100644 client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp create mode 100644 example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp diff --git a/client_example/16_convnd_fwd/CMakeLists.txt b/client_example/16_convnd_fwd/CMakeLists.txt index e034c468d5..808693b632 100644 --- a/client_example/16_convnd_fwd/CMakeLists.txt +++ b/client_example/16_convnd_fwd/CMakeLists.txt @@ -17,6 +17,11 @@ if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) target_link_libraries(client_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations) endif() +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) + add_executable(client_conv3d_fwd_fp8_bf8 conv3d_fwd_fp8_bf8.cpp) + target_link_libraries(client_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations) +endif() + if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES) add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp) target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_conv_operations) diff --git a/client_example/16_convnd_fwd/common.hpp b/client_example/16_convnd_fwd/common.hpp index a5b7c5b42e..ee408c7443 100644 --- a/client_example/16_convnd_fwd/common.hpp +++ b/client_example/16_convnd_fwd/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include #include @@ -95,7 +95,8 @@ template + typename AComputeType = InDataType, + typename BComputeType = AComputeType> bool run_grouped_conv_fwd(std::array in_lengths, std::array wei_lengths, std::array out_lengths) @@ -186,7 +187,8 @@ bool run_grouped_conv_fwd(std::array; + AComputeType, + BComputeType>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp new file mode 100644 index 0000000000..8508dc9c55 --- /dev/null +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::bf8_t; +using OutDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +using AComputeType = ck::f8_t; +using BComputeType = ck::bf8_t; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? 
EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 61e9a43c3a..afbe741212 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -5,6 +5,7 @@ add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp) add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp) +add_example_executable(example_convnd_fwd_xdl_fp8_bf8 convnd_fwd_xdl_fp8_bf8.cpp) add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp) diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp new file mode 100644 index 0000000000..53a12377c5 --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "convnd_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::bf8_t; +using AccDataType = float; +using CShuffleDataType = ck::f8_t; +using OutDataType = ck::f8_t; +using AComputeType = ck::f8_t; +using BComputeType = ck::bf8_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // 
BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeType, + BComputeType>; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp index fa3dcfdf20..31e8d639ad 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -40,7 +40,8 @@ using is_tuple = decltype(std::declval().IsTuple()); * \tparam AElementwiseOperation A elementwise operation. * \tparam BElementwiseOperation B elementwise operation. * \tparam CDEElementwiseOperation CDE elementwise operation. - * \tparam ComputeType Compute data type (default: ADataType, first if tuple passed). + * \tparam AComputeType Compute data type for A tensor (default: ADataType, first if tuple passed). + * \tparam BComputeType Compute data type for B tensor (default: AComputeType). 
*/ template ::value, Number<0>, - ADataType>())> // ComputeType is InputType by default (first + ADataType>()), // AComputeType is InputType by default (first // in tuple for MultiAB), unpack if tuple was // passed + typename BComputeType = AComputeType> struct DeviceGroupedConvFwdMultipleABD : public BaseOperator { static constexpr bool isMultiA = is_detected::value; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 5ff42f98f3..f53ec8a4e8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -254,13 +254,14 @@ template ::value, Number<0>, ADataType>()), // ComputeType is InputType by default (first // in tuple for MultiAB), unpack if tuple was // passed - LoopScheduler LoopSched = make_default_loop_scheduler()> + typename BComputeDataType = AComputeDataType, + LoopScheduler LoopSched = make_default_loop_scheduler()> struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle : public DeviceGroupedConvFwdMultipleABD + AComputeDataType, + BComputeDataType> { using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle; @@ -386,7 +388,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle using GemmBDataType = std::conditional_t, BDataType>; #define GridwiseGemmTemplateParameters \ - GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ + GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ @@ -399,7 +401,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ - CDEBlockTransferScalarPerVector_NPerBlock, LoopSched + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ + BComputeDataType // Use appropriate gridwise gemm using GridwiseGemm = std::conditional_t::value, Number<0>, ADataType>()), // ComputeType is InputType by default (first // in tuple for MultiAB), unpack if tuple was // passed - LoopScheduler LoopSched = make_default_loop_scheduler()> + typename BComputeDataType = AComputeDataType, + LoopScheduler LoopSched = make_default_loop_scheduler()> using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, @@ -128,7 +129,8 @@ using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipl CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, - ComputeDataType, + AComputeDataType, + BComputeDataType, LoopSched>; } // namespace device diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index 4b7cc56796..0f98f9e63d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -30,7 +30,7 @@ namespace ck { // D0, D1, ... 
and E have the same layout template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename BComputeDataType_ = AComputeDataType_> struct GridwiseGemmMultipleABD_xdl_cshuffle { static constexpr index_t NumATensor = AsDataType::Size(); @@ -101,10 +102,13 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle decltype(GridwiseGemmPipeline_Selector())>; #if CK_WORKAROUND_DENORM_FIX - using ComputeDataType = - conditional_t, ck::bhalf_t, ComputeDataType_>; + using AComputeDataType = + conditional_t, ck::bhalf_t, AComputeDataType_>; + using BComputeDataType = + conditional_t, ck::bhalf_t, BComputeDataType_>; #else - using ComputeDataType = ComputeDataType_; + using AComputeDataType = AComputeDataType_; + using BComputeDataType = BComputeDataType_; #endif __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -195,8 +199,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(ComputeDataType), + return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) + + b_block_space_size_aligned * sizeof(BComputeDataType), c_block_size * sizeof(CShuffleDataType)); } @@ -597,7 +601,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< ThisThreadBlock, AsDataType, - Tuple, + Tuple, decltype(as_grid_desc_ak0_m_ak1), decltype(tie(a_block_desc_ak0_m_ak1)), AElementwiseOperation, @@ -628,7 +632,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< ThisThreadBlock, BsDataType, - Tuple, + Tuple, decltype(bs_grid_desc_bk0_n_bk1), decltype(tie(b_block_desc_bk0_n_bk1)), BElementwiseOperation, @@ -656,14 +660,15 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // 
sanity check - constexpr index_t KPack = - math::max(math::lcm(AK1, BK1), - MfmaSelector::selected_mfma.k_per_blk); + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), + MfmaSelector::selected_mfma + .k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - ComputeDataType, // ComputeDataType for A - ComputeDataType, // ComputeDataType for B + AComputeDataType, + BComputeDataType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -681,10 +686,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index c0a3d29f85..6ddc3aca18 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -73,7 +73,7 @@ template + typename BComputeDataType_ = AComputeDataType_> struct GridwiseGemmMultipleD_xdl_cshuffle { static constexpr index_t NumDTensor = DsDataType::Size(); @@ -103,8 +103,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle #if CK_WORKAROUND_DENORM_FIX using AComputeDataType = conditional_t, ck::bhalf_t, AComputeDataType_>; + using BComputeDataType = + conditional_t, ck::bhalf_t, BComputeDataType_>; #else using AComputeDataType = AComputeDataType_; + using BComputeDataType = BComputeDataType_; #endif __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp index 0f845ca1ed..40878e4f0e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp @@ -290,6 +290,42 @@ using device_grouped_conv_fwd_xdl_bf8_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_fwd_xdl_f8_bf8_instances = std::tuple< +// clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|AComputeType|BComputeType| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| 
Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, + 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8> +#endif + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 
24a5f9a5cb..e61ec28284 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -34,7 +34,8 @@ template + typename AComputeType, + typename BComputeType> struct DeviceOperationInstanceFactory> + AComputeType, + BComputeType>> { using DeviceOp = DeviceGroupedConvFwdMultipleABD; + AComputeType, + BComputeType>; static auto GetInstances() { @@ -75,14 +78,16 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); } @@ -94,14 +99,16 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); } @@ -115,14 +122,16 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); } @@ -130,14 +139,17 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && + is_same_v && + is_same_v) { 
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_INT8 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(op_ptrs); } @@ -149,14 +161,16 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); } @@ -164,7 +178,9 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && + is_same_v && + is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs); } @@ -176,14 +192,16 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); } @@ -191,7 +209,9 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && + is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs); } @@ -203,14 +223,16 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { 
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); } @@ -218,14 +240,17 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && + is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_INT8 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(op_ptrs); } @@ -237,7 +262,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); } @@ -245,28 +271,40 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instances( op_ptrs); } if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF8 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances(op_ptrs); } #endif +#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances(op_ptrs); + } +#endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); } @@ -274,14 +312,17 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && + is_same_v && + is_same_v) { 
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_INT8 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs); } @@ -295,7 +336,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances(op_ptrs); @@ -305,7 +347,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs); add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(op_ptrs); @@ -320,7 +363,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(op_ptrs); add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances(op_ptrs); @@ -335,7 +379,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances( @@ -347,7 +392,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances(op_ptrs); @@ -363,7 +409,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { 
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances( @@ -375,7 +422,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances(op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc index 942674ef99..691414ebcb 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -351,6 +351,24 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances( BF8>>>& instances); #endif +#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances( + std::vector>>& instances); +#endif + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 972fb54031..50a6ec9a45 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -41,4 +41,9 @@ if(DTYPES MATCHES "bf8" OR NOT DEFINED DTYPES) xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp) endif() +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) + list(APPEND GROUPED_CONV3D_FWD + xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp) +endif() + 
add_instance_library(device_grouped_conv3d_fwd_instance ${GROUPED_CONV3D_FWD}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp new file mode 100644 index 0000000000..d42104bf6b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f8_bf8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f8_bf8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f8_bf8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp index f629809daa..d913873305 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp @@ -1,5 +1,5 
@@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -31,7 +31,9 @@ template + typename OutDataType, + typename AComputeType = InDataType, + typename BComputeType = AComputeType> bool profile_grouped_conv_fwd_impl(int do_verification, int init_method, bool do_log, @@ -209,7 +211,9 @@ bool profile_grouped_conv_fwd_impl(int do_verification, OutDataType, InElementOp, WeiElementOp, - OutElementOp>; + OutElementOp, + AComputeType, + BComputeType>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index 1f72733729..a847999b56 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -25,6 +25,7 @@ enum struct ConvDataType INT8_INT8_INT8, // 3 F8_F8_F8, // 4 BF8_BF8_F8, // 5 + F8_BF8_F8, // 6 }; #define OP_NAME "grouped_conv_fwd" @@ -40,7 +41,8 @@ static void print_helper_msg() << " 2: Input bf16, Weight bf16, Output bf16\n" << " 3: Input int8, Weight int8, Output int8\n" << " 4: Input fp8, Weight fp8, Output fp8\n" - << " 5: Input bf8, Weight bf8, Output fp8)\n" + << " 5: Input bf8, Weight bf8, Output fp8\n" + << " 6: Input fp8, Weight bf8, Output fp8)\n" << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n" << "arg4: verification (0: no, 1: yes)\n" @@ -118,7 +120,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) auto out_layout, auto in_type, auto wei_type, - auto out_type) { + auto out_type, + auto a_compute_type, + auto b_compute_type) { constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; using InLayout = decltype(in_layout); @@ -129,13 +133,18 @@ int 
profile_grouped_conv_fwd(int argc, char* argv[]) using WeiDataType = decltype(wei_type); using OutDataType = decltype(out_type); + using AComputeType = decltype(a_compute_type); + using BComputeType = decltype(b_compute_type); + bool pass = ck::profiler::profile_grouped_conv_fwd_impl( + OutDataType, + AComputeType, + BComputeType>( do_verification, init_method, do_log, time_kernel, params); return pass ? 0 : 1; @@ -146,57 +155,59 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}); + return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}); + return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, BF16{}, BF16{}); + return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } else if(data_type == ConvDataType::INT8_INT8_INT8) { - return profile(I1, GNWC{}, GKXC{}, GNWK{}, INT8{}, INT8{}, INT8{}); + return profile(I1, GNWC{}, GKXC{}, GNWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); } } else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}); + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}); + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, BF16{}, BF16{}); + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, BF16{}, BF16{}, BF16{}, 
BF16{}); } else if(data_type == ConvDataType::INT8_INT8_INT8) { - return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, INT8{}, INT8{}, INT8{}); + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); } } else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}); + return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}); + return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, BF16{}, BF16{}); + return profile( + I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } else if(data_type == ConvDataType::INT8_INT8_INT8) { - return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, INT8{}, INT8{}, INT8{}); + return profile( + I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); } } // NHWGC_GKYXC_NHWGK @@ -204,65 +215,71 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}); + return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I1, NWGC{}, GKXC{}, NWGK{}, F16{}, F16{}, F16{}); + return profile(I1, NWGC{}, GKXC{}, NWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I1, NWGC{}, GKXC{}, NWGK{}, BF16{}, BF16{}, BF16{}); + return profile(I1, NWGC{}, GKXC{}, NWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } else if(data_type == ConvDataType::INT8_INT8_INT8) { - return profile(I1, NWGC{}, GKXC{}, NWGK{}, INT8{}, INT8{}, 
INT8{}); + return profile(I1, NWGC{}, GKXC{}, NWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); } } else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}); + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}); + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}); + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } else if(data_type == ConvDataType::INT8_INT8_INT8) { - return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{}); + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); } } else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}); + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}); + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}); + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } else if(data_type == ConvDataType::INT8_INT8_INT8) { - return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, INT8{}, INT8{}, INT8{}); + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); } else if(data_type == 
ConvDataType::F8_F8_F8) { - return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}); + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}, F8{}, F8{}); } else if(data_type == ConvDataType::BF8_BF8_F8) { - return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, BF8{}, F8{}); + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, BF8{}, F8{}, BF8{}, BF8{}); + } + else if(data_type == ConvDataType::F8_BF8_F8) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, BF8{}, F8{}, F8{}, BF8{}); } } From c701071666ce5656c8bd4331979f56fcc497fda6 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Thu, 4 Apr 2024 11:01:33 +0200 Subject: [PATCH 5/7] Add Grouped Gemm Multiple D SplitK TwoStage (#1212) * Support A/B/C elementwise ops. * First part of GGEMM multiD splitk two stage. * WIP - changes for debugging. * tmp save * working version * added bf16@int8 version * fixes * add reviewers' suggestions * pre-committed missing files * switched to ifs from elseifs --------- Co-authored-by: Adam Osewski --- ...rouped_gemm_multiple_d_splitk_xdl_fp16.cpp | 394 +++++++ .../device_grouped_gemm_multiple_d_splitk.hpp | 136 +++ .../device/impl/device_elementwise_impl.hpp | 16 +- ...ltiple_d_splitk_xdl_cshuffle_two_stage.hpp | 987 ++++++++++++++++++ ...evice_grouped_gemm_xdl_splitk_cshuffle.hpp | 27 +- .../cpu/reference_gemm_multiple_d.hpp | 175 ++++ .../gpu/grouped_gemm.hpp | 50 +- .../gpu/grouped_gemm/CMakeLists.txt | 2 + ...o_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp | 99 ++ ...wo_stage_f16_f16_f16_mk_kn_mn_instance.cpp | 96 ++ .../profile_grouped_gemm_two_stage_impl.hpp | 366 +++++++ profiler/src/CMakeLists.txt | 1 + .../src/profile_grouped_gemm_two_stage.cpp | 157 +++ 13 files changed, 2490 insertions(+), 16 deletions(-) create mode 100644 example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp create mode 100644 include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp create mode 100644
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp create mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp create mode 100644 profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp create mode 100644 profiler/src/profile_grouped_gemm_two_stage.cpp diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp new file mode 100644 index 0000000000..ecff7b4713 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp @@ -0,0 +1,394 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include +#include + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAdd = ck::tensor_operation::element_wise::AddAdd; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F32; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAdd; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +static constexpr int NumDMatrices = 2; + +using DeviceGemmInstance = + ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector> stride_Ds; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + int k_batch = 128; + bool time_kernel = true; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + 
std::vector p_Cs; + std::vector p_As; + std::vector p_Bs; + std::vector> p_Ds = {}; + + gemm_descs.reserve(group_count); + p_As.reserve(group_count); + p_Bs.reserve(group_count); + p_Ds.reserve(group_count); + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector, NumDMatrices>> d_tensors; + std::vector> c_host_tensors; + std::vector> c_device_result_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + d_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_result_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + std::vector> d_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + d_tensors_device.resize(group_count); // resize, not reserve: d_tensors_device[i] is indexed below before anything is pushed to the outer vector + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + + auto d0_tensor = Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); + auto d1_tensor = Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); + + std::array, NumDMatrices> d_tens = {d0_tensor, d1_tensor}; + d_tensors.push_back(d_tens); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], 
problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_result_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc + << " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() + + sizeof(BDataType) * b_tensors[i].GetElementSize() + + sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDMatrices + + sizeof(EDataType) * c_device_result_tensors[i].GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + for(int j = 0; j < NumDMatrices; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + } + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + for(int j = 0; j < NumDMatrices; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + for(int j = 0; j < NumDMatrices; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + } + } + } + + for(int i = 0; i < group_count; i++) + { + a_tensors_device.emplace_back( + std::make_unique(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType))); + + b_tensors_device.emplace_back( + std::make_unique(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType))); + + c_tensors_device.emplace_back(std::make_unique( + c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType))); + + for(int j = 0; j < 
NumDMatrices; ++j) + { + d_tensors_device[i].emplace_back(std::make_unique( + d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType))); + } + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + for(int j = 0; j < NumDMatrices; ++j) + { + d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data()); + } + c_tensors_device[i]->SetZero(); + p_As.push_back(a_tensors_device[i]->GetDeviceBuffer()); + p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer()); + p_Ds.push_back( + {d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()}); + p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); + gemm_descs.push_back({problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + problem_size.stride_Cs[i], + problem_size.stride_Ds[i]}); + } + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + // do GEMM + auto argument = gemm.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op); + gemm.SetKBatchSize(argument, config.k_batch); + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); + gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + + invoker.Run(argument, StreamConfig{nullptr, false, 1}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemmMultipleD; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data(), // copy group i's device result back to the host + c_device_result_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + d_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + cde_element_op); + + ref_invoker.Run(ref_argument); + pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]); + } + + std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" 
<< std::endl; + } + + return pass; +} + +std::vector argToIntArray(char* input) +{ + std::vector out; + + std::istringstream in(input); + + std::string item; + + while(std::getline(in, item, ',')) + { + out.push_back(std::stoi(item)); + } + + return out; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + if(argc < 11) + { + std::vector Ms{64, 127, 255, 129, 260, 190, 77}; + problem_size.group_count = Ms.size(); + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(Ms[i]); + problem_size.Ns.push_back(252); + problem_size.Ks.push_back(4608); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + + problem_size.stride_Ds.push_back({}); + for(int j = 0; j < NumDMatrices; ++j) + { + problem_size.stride_Ds[i].push_back(problem_size.Ns[i]); + } + } + + std::cout + << "Usage:\n" + << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=n0, 1=yes)\n" + << "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " + "64,64 64,64 128,128)\n" + << "arg10: k_batch (> 0)\n" + << "... setting default values." 
<< std::endl; + } + else + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[10]); + + problem_size.Ms = argToIntArray(argv[4]); + problem_size.Ns = argToIntArray(argv[5]); + problem_size.Ks = argToIntArray(argv[6]); + + problem_size.stride_As = argToIntArray(argv[7]); + problem_size.stride_Bs = argToIntArray(argv[8]); + problem_size.stride_Cs = argToIntArray(argv[9]); + + for(int j = 0; j < NumDMatrices; ++j) + { + problem_size.stride_Ds.push_back(problem_size.stride_Cs); + } + + problem_size.group_count = problem_size.Ms.size(); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp new file mode 100644 index 0000000000..d91eac0730 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "device_grouped_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/// +/// @brief Structure representing single GEMM problem arguments. +/// +/// The pointer to the vector of those structures is passed to the GroupedGEMM entry +/// point kernel. +/// +/// @tparam NumDTensor The number of D input tensors. 
+/// +template +struct GroupedGemmMultipleDKernelArguments +{ + __host__ __device__ + GroupedGemmMultipleDKernelArguments(const void* p_a_grid_, + const void* p_b_grid_, + std::array p_ds_grid_, + void* p_e_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideE_) + : p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_ds_grid{p_ds_grid_}, + p_e_grid{p_e_grid_}, + M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideDs{StrideDs_}, + StrideE{StrideE_} + { + } + + const void* p_a_grid; + const void* p_b_grid; + std::array p_ds_grid; + void* p_e_grid; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideE; + + void Print() const + { + std::stringstream str; + for(auto sd : StrideDs) + str << sd << ","; + + std::cout << "arg {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SE:" << StrideE << ", " + << "SDs: {" << str.str() << "}" + << "}" << std::endl; + } +}; + +template +struct DeviceGroupedGemmMultipleDSplitK : public DeviceGroupedGemm +{ + //---------------------------------------------------------------------------------------------- + /// @brief Sets the k batch size. + /// + /// @param p_arg Pointer to the Argument we're going to change. + /// @param[in] kbatch The kbatch value. + /// + virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0; + + //---------------------------------------------------------------------------------------------- + /// @brief Sets the device kernel arguments pointer. + /// + /// @param p_arg The pointer to the Argument we're going to update. + /// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel + /// arguments. 
+ /// + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0; + + //---------------------------------------------------------------------------------------------- + /// @brief Gets the device kernel argument size. + /// + /// @param[in] p_arg The pointer to the Device op Argument. + /// + /// @return The device kernel argument size. + /// + virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp index 37867f1eaa..1a44c3ed9c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -22,10 +22,12 @@ namespace device { template + index_t NumDim, // The max dim of input tensors + // the tensors descs have to be aligned, such that + // the innermost dim is the contiguous one. 
+ index_t MPerThread, // How many elements per thread to read + typename InScalarPerVectorSeq, // Scalar per vec for each Input + typename OutScalarPerVectorSeq> // Scalar per vec for each Output struct DeviceElementwiseImpl : public DeviceElementwise { @@ -242,13 +244,13 @@ struct DeviceElementwiseImpl static_for<0, NumInput, 1>{}([&](auto I) { if(!IsScalarPerVectorValid( arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I))) - valid = false; + valid = valid && false; }); static_for<0, NumOutput, 1>{}([&](auto I) { if(!IsScalarPerVectorValid( arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I))) - valid = false; + valid = valid && false; }); return valid; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp new file mode 100644 index 0000000000..2d60c027bb --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp @@ -0,0 +1,987 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/utility/common_header.hpp" +#include +#include "ck/utility/tuple.hpp" +#include "ck/utility/sequence_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { + +template = false> +struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage + : public DeviceGroupedGemmMultipleDSplitK +{ + using DeviceOp = DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + // TODO change GridwiseGEMM v2r4r2 to support separate AK1 & BK1 + static constexpr index_t K0PerBlock = KPerBlock / AK1; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using WorkspaceDataType = float; + + // First stage GridwiseGEMM kernel. 
+ using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< + BlockSize, + ADataType, + BDataType, + AccDataType, + WorkspaceDataType, + ALayout, + BLayout, + ELayout, + AElementwiseOperation, + BElementwiseOperation, + PassThrough, // CElementwiseOperation + GemmSpec, + NumGemmKPrefetchStage, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + AK1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + LoopSched, + PipelineVer, + ComputeDataType>; + + template + static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, 
Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + static constexpr auto MakeElementwiseInputSequence() + { + return generate_sequence_v2( + [&]([[maybe_unused]] auto i) constexpr { + return Number{}; + }, + Number{}); + } + + using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; + using EGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; + using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {})); + using DsGridPointer = decltype(MakeDsGridPointer()); + using CDGridDesc_M_N = decltype(concat_tuple(ck::Tuple{}, DsGridDesc_M_N{})); + using CDDataTypes = decltype(concat_tuple(ck::Tuple{}, DsGridPointer{})); + + using ElementwiseInputSequence = decltype(MakeElementwiseInputSequence()); + + static constexpr index_t ClusterLengthMPerBlock = + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t ClusterLengthNPerBlock = + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + + using Block2ETileMapKSplit = + BlockToCTileMap_KSplit_M00_N0_M01Adapt; + using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt; + using GridwiseElementwise = + GridwiseElementwise, + CDDataTypes, + ck::Tuple, + Block2TileMap, + CDEElementwiseOperation, + BlockSize, + MPerBlock, + 
NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<0, 1>, + ElementwiseInputSequence, + ck::Sequence, + true>; + + // Block2CTileMap configuration parameter. + static constexpr index_t B2E_M01 = 8; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; + using GemmKernelArgument = typename GridwiseGemm::Argument; + + struct GemmTransKernelArg + { + GemmKernelArgument karg_; + GroupedGemmBlock2ETileMap block_2_ctile_map_; + index_t block_start_, block_end_; + + GemmTransKernelArg() = default; + GemmTransKernelArg(GemmKernelArgument&& karg, + GroupedGemmBlock2ETileMap&& b2c_map, + index_t block_start, + index_t block_end) + : karg_{karg}, + block_2_ctile_map_{b2c_map}, + block_start_{block_start}, + block_end_{block_end} + { + } + }; + + static constexpr index_t DefaultKBatch = 1; + + // Argument + struct Argument : public BaseArgument + { + + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : Argument(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_element_op, + b_element_op, + cde_element_op, + DefaultKBatch) + { + } + + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + index_t kbatch) + : K_BATCH{kbatch}, + group_count_{0}, + skipped_group_count_{0}, + grid_size_{0}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + p_Ds_{p_Ds} + { + group_count_ = ck::type_convert(gemm_descs.size()); + + if(!(group_count_ == ck::type_convert(p_As.size()) && + group_count_ == ck::type_convert(p_Bs.size()) && + group_count_ == ck::type_convert(p_Es.size()))) + { + throw std::runtime_error("Error! 
group_count_ != p_As/Bs/Ds/Es size"); + } + + gemm_kernel_args_.reserve(group_count_); + elementwise_c_grid_descs_m_n_.reserve(group_count_); + elementwise_d_grid_descs_m_n_.reserve(group_count_); + ds_grid_pointer_.reserve(group_count_); + group_grid_size_.reserve(group_count_); + + for(std::size_t i = 0; i < gemm_descs.size(); ++i) + { + const index_t M = gemm_descs[i].M_; + const index_t N = gemm_descs[i].N_; + const index_t K = gemm_descs[i].K_; + + if(M == 0 || N == 0 || K == 0) // tested per dimension: the product M * N * K can overflow index_t and wrap to 0 + { + skipped_group_count_++; + continue; + } + + const index_t stride_a = gemm_descs[i].stride_A_; + const index_t stride_b = gemm_descs[i].stride_B_; + const index_t stride_e = gemm_descs[i].stride_C_; + + const index_t m_padded = GridwiseGemm::CalculateMPadded(M); + const index_t n_padded = GridwiseGemm::CalculateNPadded(N); + const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH); + const index_t k0_padded = GridwiseGemm::CalculateK0Padded(K, K_BATCH); + + const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_e); + + DsGridDesc_M_N ds_grid_desc_m_n; + DsGridPointer p_ds_grid; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + p_ds_grid(j) = static_cast(p_Ds[i][j]); + ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N( + M, N, gemm_descs[i].stride_Ds_[j]); + }); + const auto local_b2c_tile_map = + Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; + const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); + + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + group_grid_size_.push_back(grid_size_grp); // appended (the vector is only reserved) so indices line up with gemm_kernel_args_, which omits skipped groups + // block-to-e-tile map + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + std::array stride_ds; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + if(gemm_descs[i].stride_Ds_.size() != NumDTensor) + { + throw 
std::runtime_error( + "Error! gemm_descs[i].stride_Ds_.size() does not match NumDTensor"); + } + + stride_ds[j] = gemm_descs[i].stride_Ds_[j]; + }); + stride_Ds_.emplace_back(std::move(stride_ds)); + + // We first set E pointer to actual operation output, but later on + // when workspace will be set, this will be updated to workspace memory. + auto karg = GemmKernelArgument{type_convert(p_As[i]), + type_convert(p_Bs[i]), + type_convert(p_Es[i]), + M, + N, + K, + stride_a, + stride_b, + stride_e, + m_padded, + n_padded, + k_padded, + k0_padded, + K_BATCH}; + + gemm_kernel_args_.emplace_back( + std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end); + + elementwise_c_grid_descs_m_n_.push_back(c_grid_desc_m_n); + elementwise_d_grid_descs_m_n_.push_back(ds_grid_desc_m_n); + ds_grid_pointer_.push_back(p_ds_grid); + } + // Store a copy of E pointers for elementwise kernel destination + e_ptrs_ = p_Es; + } + + /** + * @brief Set new kbatch value. + * + * @param[in] kbatch The new splitK parameter value. 
+ */ + void UpdateKBatch(index_t kbatch) + { + K_BATCH = kbatch; + grid_size_ = 0; + + for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i) + { + auto& karg = gemm_kernel_args_[i].karg_; + + const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH); + const index_t k0_padded = GridwiseGemm::CalculateK0Padded(karg.K, K_BATCH); + + const auto c_grid_desc_m_n = + GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); + + const auto local_b2c_tile_map = + Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; + const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); + + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + // block-to-e-tile map + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + group_grid_size_[i] = grid_size_grp; + karg.KPadded = k_padded; + karg.K0Padded = k0_padded; + karg.k_batch = K_BATCH; + gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map; + gemm_kernel_args_[i].block_start_ = block_start; + gemm_kernel_args_[i].block_end_ = block_end; + +#if DEBUG_LOG + index_t tiles = (block_end - block_start) / K_BATCH; + std::cout << "block_start: " << block_start << "\n" + << "block_end: " << block_end << "\n" + << "tiles: " << tiles << std::endl + << std::endl; + + std::cout << "KPadded: " << karg.KPadded << std::endl + << "K0Padded: " << karg.K0Padded << std::endl + << "KBatch: " << karg.k_batch << std::endl + << "grid_size_: " << grid_size_ << std::endl; +#endif + } + } + + void UpdateEPointers() + { + // set-up each group E pointer to its designated workspace memory. 
+ WorkspaceDataType* p_workspace = reinterpret_cast(p_workspace_); + std::size_t offset = 0; + + for(auto& arg : gemm_kernel_args_) + { + arg.karg_.p_c_grid = p_workspace + offset; + index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch; + offset += static_cast<std::size_t>(tiles) * MPerBlock * NPerBlock; // widened first so the product is not evaluated in 32-bit index_t +#if DEBUG_LOG + std::cout << "block_start: " << arg.block_start_ << "\n" + << "block_end: " << arg.block_end_ << "\n" + << "tiles: " << tiles << "\n" + << "offset: " << offset << std::endl; +#endif + } + } + + std::size_t GetWorkspaceSizeBytes() const + { + std::size_t size_bytes{0}; + + for(const auto& arg : gemm_kernel_args_) + { + index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch; + size_bytes += static_cast<std::size_t>(tiles) * MPerBlock * NPerBlock * sizeof(WorkspaceDataType); // widened first, see UpdateEPointers + } + return size_bytes; + } + + std::size_t GetWorkspaceSize(std::size_t group) const + { + const auto& arg = gemm_kernel_args_[group]; + index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch; + return static_cast<std::size_t>(tiles) * MPerBlock * NPerBlock; // element count (not bytes), computed in std::size_t + } + + // private: + index_t K_BATCH; + index_t group_count_; + index_t skipped_group_count_; + index_t grid_size_; + // Pointer to device memory with GEMM kernel arguments. + const void* p_dev_gemm_args_{nullptr}; // null until set; the constructors do not initialise it and Invoker::Run tests it against nullptr + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + std::vector>& p_Ds_; + std::vector> stride_Ds_; + std::vector gemm_kernel_args_; + std::vector group_grid_size_; + + std::vector elementwise_c_grid_descs_m_n_; + std::vector elementwise_d_grid_descs_m_n_; + std::vector ds_grid_pointer_; + std::vector e_ptrs_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + /// + /// @brief Launch Grouped Gemm kernel. + /// + /// @note This function overload is using user provided device buffer for kernel + /// arguments. + /// + /// @param[in] arg The structure containing kernel arguments (in host + /// memory). 
+ /// @param[in] dev_gemm_args The pointer to device memory with kernel arguments. + /// @param[in] dev_gemm_workspace The pointer to device memory for kernel auxiliary + /// workspace. + /// @param[in] stream_config The device stream configuration. + /// + /// @return The average kernel execution time (if time measurement is enabled.) + /// + float Run(const Argument& arg, + const void* dev_gemm_args, + void* dev_gemm_workspace, + const StreamConfig& stream_config = StreamConfig{}) + { + auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] = + CheckArgument(arg, stream_config); + + if(dev_gemm_args == nullptr) + { + std::ostringstream err; + err << "The gemm arguments device buffer is not allocated!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + if(dev_gemm_workspace == nullptr) + { + std::ostringstream err; + err << "The gemm workspace buffer is not allocated!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + float ave_time = 0; + + if(all_have_main_k_block_loop) + { + ave_time = + DispatchKernel(arg, dev_gemm_args, dev_gemm_workspace, stream_config); + } + else + { + ave_time = + DispatchKernel(arg, dev_gemm_args, dev_gemm_workspace, stream_config); + } + + return ave_time; + } + + /// + /// @brief Launch Grouped Gemm kernel. + /// + /// @note This function overload is using device buffers (for kernel arguments and + /// for kernel auxiliary workspace) provided with an argument. The user should + /// call @see GetDeviceKernelArgSize, @see GetWorkSpaceSize and @see + /// SetDeviceKernelArgs, @see SetWorkSpacePointer on arg parameter to properly + /// allocate those buffers. + /// + /// @param[in] arg The structure containing kernel arguments (in host memory). + /// @param[in] stream_config The device stream configuration. + /// + /// @return The average kernel execution time (if time measurement is enabled.) 
+ /// + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(arg.p_dev_gemm_args_ == nullptr) + { + std::ostringstream err; + err << "The gemm arguments device buffer is not allocated!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + if(arg.p_workspace_ == nullptr) + { + std::ostringstream err; + err << "The gemm workspace buffer is not allocated!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + return Run(arg, arg.p_dev_gemm_args_, arg.p_workspace_, stream_config); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + + private: + auto CheckArgument(const Argument& arg, const StreamConfig& stream_config) const + { + bool all_have_kbatch_gt_one, all_have_main_k_block_loop; + + { + const auto a_grid_desc_kbatch_ak0_m_ak1 = + GridwiseGemm::MakeAGridDescriptor_KBatch_K0_M_K1( + arg.gemm_kernel_args_[0].karg_.M, + arg.gemm_kernel_args_[0].karg_.MPadded, + arg.gemm_kernel_args_[0].karg_.K, + arg.gemm_kernel_args_[0].karg_.StrideA, + arg.gemm_kernel_args_[0].karg_.k_batch, + arg.gemm_kernel_args_[0].karg_.K0Padded, + arg.gemm_kernel_args_[0].karg_.KPadded); + + all_have_kbatch_gt_one = arg.K_BATCH > 1; + all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop( + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)); + } + + for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) + { + const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_; + if(stream_config.log_level_ > 0) + { + gemm_arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(gemm_arg)) + { + std::ostringstream err; + err << "Group id: " << i << " has invalid GridwiseGemm settings!" 
<< __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + const auto a_grid_desc_kbatch_ak0_m_ak1 = + GridwiseGemm::MakeAGridDescriptor_KBatch_K0_M_K1(gemm_arg.M, + gemm_arg.MPadded, + gemm_arg.K, + gemm_arg.StrideA, + gemm_arg.k_batch, + gemm_arg.K0Padded, + gemm_arg.KPadded); + + bool not_all_have_main_k_block_loop_same = + all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop( + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)); + bool not_all_have_kbatch_value_same = + all_have_kbatch_gt_one xor (gemm_arg.k_batch > 1); + + if(not_all_have_main_k_block_loop_same) + { + std::ostringstream err; + err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + if(not_all_have_kbatch_value_same) + { + std::ostringstream err; + err << "Not all gemms have same kbatch value (=1 or >1)! 
" + << "group [" << i << "], kbatch: " << gemm_arg.k_batch + << ", group [0], kbatch: " << gemm_arg.k_batch << " in " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k_block_loop); + } + + template + float DispatchKernel(const Argument& arg, + const void* dev_gemm_args, + void* dev_gemm_workspace, + const StreamConfig& stream_config) const + { + const auto gemm_kernel = + kernel_grouped_gemm_xdl_splitk; + + const auto elementwise_kernel = kernel_elementwise, + CDDataTypes, + ck::Tuple, + Block2TileMap, + CDEElementwiseOperation>; + return LaunchKernel(gemm_kernel, + elementwise_kernel, + arg, + dev_gemm_args, + dev_gemm_workspace, + stream_config); + } + + template + float LaunchKernel(const KernelFunction& gemm_kernel, + const KernelFunction2& elementwise_kernel, + const Argument& arg, + const void* dev_gemm_args, + [[maybe_unused]] void* dev_gemm_workspace, + const StreamConfig& stream_config) const + { + float time{0.f}; + + auto preprocess = [&]() { + hip_check_error(hipMemsetAsync( + dev_gemm_workspace, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_)); + }; + + // GEMM kernel + time = launch_and_time_kernel_with_preprocess( + stream_config, + preprocess, + gemm_kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(dev_gemm_args), + arg.group_count_, + arg.a_element_op_, + arg.b_element_op_, + PassThrough{}); + + // Elementwise kernels + for(int i = 0; i < arg.group_count_; ++i) + { + time += launch_and_time_kernel( + stream_config, + elementwise_kernel, + dim3(arg.group_grid_size_[i]), + dim3(BlockSize), + 0, + concat_tuple(make_tuple(arg.elementwise_c_grid_descs_m_n_[i]), + arg.elementwise_d_grid_descs_m_n_[i]), + make_tuple(arg.elementwise_c_grid_descs_m_n_[i]), + concat_tuple(make_tuple(arg.gemm_kernel_args_[i].karg_.p_c_grid), + arg.ds_grid_pointer_[i]), + 
type_convert(arg.e_ptrs_[i]), + Block2TileMap{arg.elementwise_c_grid_descs_m_n_[i].GetLength(I0), + arg.elementwise_c_grid_descs_m_n_[i].GetLength(I1)}, + arg.cde_element_op_); + } + return time; + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((ck::type_convert(arg.gemm_kernel_args_.size()) + + arg.skipped_group_count_) != arg.group_count_) + { +#if DEBUG_LOG + std::cout << "The group count is not equal to sum of skipped groups " + "and kernel args size!" + << std::endl; +#endif // DEBUG_LOG + return false; + } + + bool supported = true; + for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) + { + const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_; + + bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg); + if(not group_arg_valid) + { +#if DEBUG_LOG + std::cout << "[" << __func__ << "] group id: " << i + << " has invalid GridwiseGemm settings!" 
<< std::endl; + gemm_arg.Print(); +#endif // DEBUG_LOG + } + supported = supported && group_arg_valid; + } + return supported; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_elementwise_op, + BElementwiseOperation b_elementwise_op, + CDEElementwiseOperation cde_elementwise_op) + { + return Argument{p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_elementwise_op, + b_elementwise_op, + cde_elementwise_op}; + } + + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_elementwise_op, + BElementwiseOperation b_elementwise_op, + CDEElementwiseOperation cde_elementwise_op) override + { + return std::make_unique(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_elementwise_op, + b_elementwise_op, + cde_elementwise_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage" + << "<" + << std::string(ALayout::name)[0] << "," + << std::string(BLayout::name)[0] << "," + << std::string(ELayout::name)[0] << "," + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << 
">"; + // clang-format on + + return str.str(); + } + + void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const + { + arg.p_dev_gemm_args_ = p_dev_kernel_args; + hip_check_error(hipMemcpy(p_dev_kernel_args, + arg.gemm_kernel_args_.data(), + GetDeviceKernelArgSize(&arg), + hipMemcpyHostToDevice)); + } + + void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + { + return SetDeviceKernelArgs(*dynamic_cast(p_arg), p_dev_kernel_args); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"); + } + + void SetWorkSpacePointer( + BaseArgument* p_arg, + void* p_workspace, + [[maybe_unused]] const StreamConfig& stream_config = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + p_arg_->UpdateEPointers(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"); + } + + static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); } + + void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override + { + return SetKBatchSize(*dynamic_cast(p_arg), kbatch); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + return dynamic_cast(p_arg)->gemm_kernel_args_.size() * + sizeof(GemmTransKernelArg); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp index abee2fea53..a33d7d8fb5 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -26,13 +26,19 @@ namespace device { template + InMemoryDataOperationEnum CGlobalMemoryDataOperation, + typename AElementwiseOperation = ck::tensor_operation::element_wise::PassThrough, + typename BElementwiseOperation = ck::tensor_operation::element_wise::PassThrough, + typename CElementwiseOperation = ck::tensor_operation::element_wise::PassThrough> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count) + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) @@ -64,10 +70,16 @@ __global__ void GridwiseGemm::template Run( gemm_desc_ptr[group_id].karg_, static_cast(p_shared), - gemm_desc_ptr[group_id].block_2_ctile_map_); + gemm_desc_ptr[group_id].block_2_ctile_map_, + a_element_op, + b_element_op, + c_element_op); #else ignore = gemm_descs_const; ignore = group_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -193,7 +205,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK; using KernelArgument = typename GridwiseGemm::Argument; - + using PassThrough = ck::tensor_operation::element_wise::PassThrough; struct GemmTransKernelArg { KernelArgument karg_; @@ -437,7 
+449,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK +#include + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// assumption: every D matrix has the same layout and the same datatype +template +struct ReferenceGemmMultipleD : public device::BaseOperator +{ + using DDataType = remove_cvref_t>; + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& b_k_n, + const std::array, DsDataType::Size()>& ds_m_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_m_k_{a_m_k}, + b_k_n_{b_k_n}, + ds_m_n_{ds_m_n}, + c_m_n_{c_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const Tensor& a_m_k_; + const Tensor& b_k_n_; + const std::array, DsDataType::Size()>& ds_m_n_; + Tensor& c_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceGemmMultipleD::Argument; + + float Run(const Argument& arg) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_m_k_.mDesc.GetLengths()[1]; + + AccDataType v_acc = 0; + ComputeTypeA v_a = 0; + ComputeTypeB v_b = 0; + + for(int k = 0; k < K; ++k) + { + // use PassThrough instead of ConvertBF16RTN for reference calculation + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k)); + } + else + { + arg.a_element_op_(v_a, arg.a_m_k_(m, k)); + } + // same for B matrix + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n)); + } + 
else + { + arg.b_element_op_(v_b, arg.b_k_n_(k, n)); + } + + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + } + + CDataType v_c = 0; + + if constexpr(DsDataType::Size() == 0) + { + arg.cde_element_op_(v_c, v_acc); + } + else if constexpr(DsDataType::Size() == 1) + { + arg.cde_element_op_(v_c, v_acc, arg.ds_m_n_[0](m, n)); + } + else if constexpr(DsDataType::Size() == 2) + { + arg.cde_element_op_(v_c, v_acc, arg.ds_m_n_[0](m, n), arg.ds_m_n_[1](m, n)); + } + + arg.c_m_n_(m, n) = v_c; + }; + + make_ParallelTensorFunctor( + f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& b_k_n, + const std::array, DsDataType::Size()>& ds_m_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{a_m_k, b_k_n, ds_m_n, c_m_n, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemmMultipleD" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp 
b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp index 056e906c27..d06a579811 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -146,6 +146,32 @@ void add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances( + std::vector>>& instances); + template > op_ptrs; +#if defined(CK_ENABLE_FP16) if constexpr(is_same_v && is_same_v && is_same_v) { @@ -190,6 +217,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) @@ -210,8 +239,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) +#endif +#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8) + if constexpr(is_same_v && is_same_v && + is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) @@ -228,6 +259,19 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances( + op_ptrs); + } + } +#endif return op_ptrs; } }; diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt index 2625e6cbe8..5a50eca107 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt @@ -10,4 +10,6 @@ 
add_instance_library(device_grouped_gemm_instance device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp + device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..8d3baf19ee --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I8 = int8_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Instances having AK1!=BK1 are temporarily disabled and will be re-enabled in future +// a[m, k] * b[k, n] = e[m, n] +using device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_generic_instances = + std::tuple< + // clang-format off + //#################################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#################################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| 
SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#################################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + // clang-format on + >; + +using device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#################################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#################################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| 
NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#################################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 
8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + 
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, 
F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 
32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1> + // clang-format on + >; + +void add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_generic_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..d384842343 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; +using Empty_Tuple = ck::Tuple<>; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Instances having AK1!=BK1 are temporarily disabled and will be re-enabled in future +// a[m, k] * b[k, n] = e[m, n] +using device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_generic_instances = + std::tuple< + // clang-format off + //#################################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#################################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| 
AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#################################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + // clang-format on + >; + +using device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#################################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#################################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| 
_MBlock_MWaveMPerXdl| ScalarPerVector| + //#################################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 
1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, 
Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, 
GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 
8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1> + // clang-format on + >; + +void add_device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_generic_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp new file mode 100644 index 0000000000..41dcabbfcf --- /dev/null +++ b/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp @@ -0,0 +1,366 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_grouped_gemm_two_stage_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs, + int kbatch = 1, + int n_warmup = 1, + int n_iter = 10) +{ + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::size_t group_count = Ms.size(); + + if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() && + group_count == StrideBs.size() && group_count == StrideCs.size())) + { + throw std::runtime_error("wrong! 
inconsistent M/N/Ks, StrideA/B/Cs size\n"); + } + + std::vector> a_m_k; + std::vector> b_k_n; + std::vector> c_m_n_host_results; + std::vector> c_m_n_device_results; + + for(std::size_t i = 0; i < group_count; i++) + { + a_m_k.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{}))); + b_k_n.push_back( + Tensor(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{}))); + + c_m_n_device_results.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); + + c_m_n_host_results.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); +#if DEBUG_LOG + std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i + << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i + << "]:" << c_m_n_device_results[i].mDesc << std::endl; +#endif // DEBUG_LOG + std::size_t num_thread = 1; + switch(init_method) + { + case 0: break; + case 1: + a_m_k[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + a_m_k[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_k_n[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + using DeviceMemPtr = std::unique_ptr; + std::vector a_device_buf, b_device_buf, c_device_buf; + + a_device_buf.reserve(group_count); + b_device_buf.reserve(group_count); + c_device_buf.reserve(group_count); + + std::vector p_a, p_b; + std::vector p_c; + + p_a.reserve(group_count); + p_b.reserve(group_count); + p_c.reserve(group_count); + + std::vector gemm_descs; + + 
gemm_descs.reserve(group_count); + + for(std::size_t i = 0; i < group_count; i++) + { + a_device_buf.emplace_back( + std::make_unique(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize())); + b_device_buf.emplace_back( + std::make_unique(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize())); + c_device_buf.emplace_back(std::make_unique( + sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize())); + + a_device_buf[i]->ToDevice(a_m_k[i].mData.data()); + b_device_buf[i]->ToDevice(b_k_n[i].mData.data()); + + gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); + + p_a.push_back(a_device_buf[i]->GetDeviceBuffer()); + p_b.push_back(b_device_buf[i]->GetDeviceBuffer()); + p_c.push_back(c_device_buf[i]->GetDeviceBuffer()); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm, + CLayout, + ADataType, + BDataType, + ck::Tuple<>, + CDataType, + AElementOp, + BElementOp, + CElementOp>; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + if(op_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! 
no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + float best_kbatch = 0; + + auto p_ds = std::vector>{}; + + if(do_verification) + { + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k[i], + b_k_n[i], + c_m_n_host_results[i], + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + } + } + + // profile device GEMM instances + for(auto& gemm_ptr : op_ptrs) + { + auto argument_ptr = + gemm_ptr->MakeArgumentPointer(p_a, + p_b, + p_ds, + p_c, + gemm_descs, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get())); + gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); + + std::string gemm_name = gemm_ptr->GetTypeString(); + + using DeviceOpSplitK = + ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitK, + CLayout, + ADataType, + BDataType, + ck::Tuple<>, + CDataType, + AElementOp, + BElementOp, + CElementOp>; + + // skip non-splitk grouped_gemm + if(dynamic_cast(gemm_ptr.get()) == nullptr) + { + continue; + } + + std::vector kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64}; + + if(kbatch > 0) + { + kbatch_list = {kbatch}; + } + + for(std::size_t j = 0; j < kbatch_list.size(); j++) + { + + auto kbatch_curr = kbatch_list[j]; + dynamic_cast(gemm_ptr.get()) + ->SetKBatchSize(argument_ptr.get(), kbatch_curr); + + DeviceMem gemm_arg_dev_mem(dynamic_cast(gemm_ptr.get()) + ->GetDeviceKernelArgSize(argument_ptr.get())); + 
dynamic_cast(gemm_ptr.get()) + ->SetDeviceKernelArgs(argument_ptr.get(), gemm_arg_dev_mem.GetDeviceBuffer()); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + gemm_desc_workspace.SetZero(); + for(std::size_t i = 0; i < gemm_descs.size(); i++) + c_device_buf[i]->SetZero(); + + invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, false, 0, n_warmup, n_iter}); + if(do_verification) + { + bool instance_pass = true; + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); + if(std::is_same_v && kbatch_curr > 1) + { + instance_pass = + instance_pass && ck::utils::check_err(c_m_n_device_results[i], + c_m_n_host_results[i], + "Error: Incorrect results!", + 0.06); + } + else + { + instance_pass = + instance_pass && ck::utils::check_err(c_m_n_device_results[i], + c_m_n_host_results[i]); + } + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k[i].mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n[i].mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_m_n_device_results[i].mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_host : ", c_m_n_host_results[i].mData, ",") + << std::endl; + } + } + + std::cout << "Instance: " << gemm_name << " verification " + << (instance_pass ? 
"SUCCEED" : "FAILED") << std::endl; + + pass = pass && instance_pass; + } + float ave_time = invoker_ptr->Run( + argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); + if(time_kernel) + { + std::size_t flop = 0, num_btype = 0; + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; + + num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + + sizeof(BDataType) * Ks[i] * Ns[i] + + sizeof(CDataType) * Ms[i] * Ns[i]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << ", KBatch " + << kbatch_curr << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + best_kbatch = kbatch_curr; + } + } + } + else + { + std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" + << std::endl; + } + } + } + + if(time_kernel) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << ", KBatch = " << best_kbatch + << std::endl; + } + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index cb6ffbec6c..e8992070b5 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -40,6 +40,7 @@ if(GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm_two_stage.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) endif() list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp) diff --git 
a/profiler/src/profile_grouped_gemm_two_stage.cpp b/profiler/src/profile_grouped_gemm_two_stage.cpp new file mode 100644 index 0000000000..17daf1e80c --- /dev/null +++ b/profiler/src/profile_grouped_gemm_two_stage.cpp @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_grouped_gemm_two_stage_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 +}; + +enum struct GemmDataType +{ + F16_F16_F16, // 0 + BF16_INT8_BF16 // 1 +}; + +#define OP_NAME "grouped_gemm_two_stage" +#define OP_DESC "Grouped GEMM TwoStage" + +namespace { + +std::vector argToIntArray(char* input) +{ + std::vector out; + + std::istringstream in(input); + + std::string item; + + while(std::getline(in, item, ',')) + { + out.push_back(std::stoi(item)); + } + + return out; +} + +int profile_grouped_gemm_two_stage(int argc, char* argv[]) +{ + if(argc < 14) + { + std::cout + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: fp16; 1: bf16@int8)\n" + << "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n]);\n" + << "arg4: verification (0: no; 1: yes)\n" + << "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: time kernel (0=no, 1=yes)\n" + << "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " + "64,64 64,64 128,128)\n" + << "arg14: kbatch value (default 1)\n" + << "optional:\n" + << "arg15: number of warm-up cycles (default 1)\n" + << "arg16: number of iterations (default 10)\n" + << std::endl; + + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); +
const bool time_kernel = std::stoi(argv[7]); + + const auto Ms = argToIntArray(argv[8]); + const auto Ns = argToIntArray(argv[9]); + const auto Ks = argToIntArray(argv[10]); + + auto StrideAs = argToIntArray(argv[11]); + auto StrideBs = argToIntArray(argv[12]); + auto StrideCs = argToIntArray(argv[13]); + const int kbatch = argc >= 15 ? std::stoi(argv[14]) : 1; // argv[14] is kbatch whenever it is present + + const int DefaultStrideA = Ks[0]; + const int DefaultStrideB = Ns[0]; + const int DefaultStrideC = Ns[0]; + + for(size_t i = 0; i < Ms.size(); ++i) + { + StrideAs[i] = StrideAs[i] == -1 ? DefaultStrideA : StrideAs[i]; + StrideBs[i] = StrideBs[i] == -1 ? DefaultStrideB : StrideBs[i]; + StrideCs[i] = StrideCs[i] == -1 ? DefaultStrideC : StrideCs[i]; + } + + int n_warmup = 1; + int n_iter = 10; + if(argc == 17) // argv[15] and argv[16] are the last valid entries; argv[17] would be the null terminator + { + n_warmup = std::stoi(argv[15]); + n_iter = std::stoi(argv[16]); + } + + if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_grouped_gemm_two_stage_impl( + do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatch, + n_warmup, + n_iter); + } + else if(data_type == GemmDataType::BF16_INT8_BF16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_grouped_gemm_two_stage_impl( + do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatch, + n_warmup, + n_iter); + } + else + { + throw std::runtime_error("wrong!
this GEMM data_type & layout is not implemented"); + } + return 0; +} + +} // anonymous namespace + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_gemm_two_stage); From 7e5c81fed2737312f960cd41fe9afbc02669ce27 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 4 Apr 2024 11:33:29 -0700 Subject: [PATCH 6/7] fix the latest errors with staging compiler (#1229) --- test/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 720ab468ea..bbb75c49e8 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -140,6 +140,7 @@ function(add_gtest_executable TEST_NAME) set(result ${result} PARENT_SCOPE) endfunction() +add_compile_options(-Wno-c++20-extensions) add_subdirectory(magic_number_division) add_subdirectory(space_filling_curve) add_subdirectory(conv_util) From 42ebffe822bc7d89eeef0160ac461b36b407a025 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sun, 7 Apr 2024 23:11:29 +0000 Subject: [PATCH 7/7] 1).support receipe in generate.py 2).use simplified mask type 3).change left/right to pass into karg --- example/ck_tile/01_fmha/fmha_fwd.cpp | 44 +++- example/ck_tile/01_fmha/fmha_fwd.hpp | 188 +----------------- example/ck_tile/01_fmha/generate.py | 127 +++++++++--- example/ck_tile/01_fmha/mask.hpp | 56 ++++-- example/ck_tile/01_fmha/script/smoke_test.sh | 3 +- include/ck_tile/ops/fmha.hpp | 1 + .../ck_tile/ops/fmha/block/block_masking.hpp | 178 +++++++++++++++-- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 31 ++- .../pipeline/block_fmha_pipeline_enum.hpp | 17 ++ 9 files changed, 380 insertions(+), 265 deletions(-) create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 5a6afe36f6..0eb17f7b1b 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -59,12 +59,13 @@ auto create_args(int argc, 
char* argv[]) .insert("operm", "1", "permute output") .insert("bias", "0", "add bias or not") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") - .insert("mask", - "0", - "0: no mask, 1: top-left, 2:bottom-right\n" - "'t:l,r', top-left local-attn with left right size\n" - "'b:l,r', bottom-r local-attn with left right size\n" - "'g:y,x', generic attention mask coordinate with y/x size\n") + .insert( + "mask", + "0", + "0: no mask, 1: top-left, 2:bottom-right\n" + "'t:l,r', top-left sliding window attn with left right size\n" + "'b:l,r', bottom-r sliding window attn with left right size\n" + "'g:y,x', generic attention mask coordinate with y/x size (only use this for debug)\n") .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)") .insert("lse", "0", "0 not store lse, 1 store lse") .insert("kname", "0", "if set to 1 will print kernel name") @@ -381,8 +382,9 @@ bool run(const ck_tile::ArgParser& arg_parser) batch_stride_bias, batch_stride_lse, batch_stride_o, - mask.y, - mask.x, + mask.left, + mask.right, + static_cast(mask.type), descale_q * descale_k, descale_v}; }(); @@ -498,12 +500,32 @@ bool run(const ck_tile::ArgParser& arg_parser) else if(mask.type == mask_enum::window_generic) { ck_tile::reference_batched_masking( - s_host_ref, FmhaMasks::GenericMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k}); + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, mask.right, real_seqlen_q, real_seqlen_k)); } else { - ck_tile::reference_batched_masking( - s_host_ref, FmhaMasks::CausalMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k}); + // if left window size is negative, means causal + // else means generic (for current batch) + if(mask.left < 0) + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + else + ck_tile::reference_batched_masking( + 
s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); } if(lse) { diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 9293201cd2..8ff13cfe13 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -80,177 +80,6 @@ struct FmhaMasks using CausalMask = ck_tile::GenericAttentionMask; }; -#if 0 -// internal API, don't use this directly -template -auto fmha_fwd_create_kargs_and_grids(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - void* lse_ptr, - void* o_ptr, - const void* seqstart_q_ptr, - const void* seqstart_k_ptr, - const void* seqlen_k_ptr, - ck_tile::index_t batch, - ck_tile::index_t nhead, - ck_tile::index_t nhead_k, - ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t max_seqlen_q, - float scale, - float descale_qk, - float descale_sv, - bool i_perm, - bool o_perm, - ck_tile::index_t mask_y, - ck_tile::index_t mask_x) -{ - constexpr bool is_v_rowmajor = - std::is_same_v; - - assert(nhead % nhead_k == 0); - /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, - /// seqlen_k] in this example, hence both the 'batch_stride_bias' & 'nhead_stride_bias' - /// are 0. - // setup stride_* arguments - const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); - const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); - const ck_tile::index_t stride_v = [&]() { - if constexpr(is_v_rowmajor) - return i_perm ? hdim_v : nhead_k * hdim_v; - else - return i_perm ? seqlen_k : nhead_k * seqlen_k; - }(); - const ck_tile::index_t stride_bias = (i_perm ? seqlen_k : 1 * seqlen_k); - const ck_tile::index_t stride_o = (o_perm ? 
hdim_v : nhead * hdim_v); - // setup nhead_stride_* arguments - const ck_tile::index_t nhead_stride_q = (i_perm ? seqlen_q * hdim_q : hdim_q); - const ck_tile::index_t nhead_stride_k = (i_perm ? seqlen_k * hdim_q : hdim_q); - const ck_tile::index_t nhead_stride_v = [&]() { - if constexpr(is_v_rowmajor) - return i_perm ? seqlen_k * hdim_v : hdim_v; - else - return i_perm ? hdim_v * seqlen_k : seqlen_k; - }(); - const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * seqlen_q * seqlen_k : 0 * seqlen_k); - const ck_tile::index_t nhead_stride_lse = (seqlen_q * 1); - const ck_tile::index_t nhead_stride_o = (o_perm ? seqlen_q * hdim_v : hdim_v); - // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = (nhead * seqlen_q * hdim_q); - const ck_tile::index_t batch_stride_k = (nhead_k * seqlen_k * hdim_q); - const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * seqlen_k); - const ck_tile::index_t batch_stride_bias = (0 * nhead * seqlen_q * seqlen_k); - const ck_tile::index_t batch_stride_lse = (nhead * seqlen_q * 1); - const ck_tile::index_t batch_stride_o = (nhead * seqlen_q * hdim_v); - - auto kargs = [&] { - // create group mode kernel arguments - if constexpr(FmhaKernel::kIsGroupMode) - { - return FmhaKernel::MakeKargs(q_ptr, - k_ptr, - v_ptr, - bias_ptr, - lse_ptr, - o_ptr, - seqstart_q_ptr, - seqstart_k_ptr, - seqlen_k_ptr, - hdim_q, - hdim_v, - nhead / nhead_k, - scale, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_lse, - nhead_stride_o, - mask_y, - mask_x, - descale_qk, - descale_sv); - } - else - { // create batch mode kernel arguments - return FmhaKernel::MakeKargs(q_ptr, - k_ptr, - v_ptr, - bias_ptr, - lse_ptr, - o_ptr, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - nhead / nhead_k, - scale, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - 
nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_lse, - batch_stride_o, - mask_y, - mask_x, - descale_qk, - descale_sv); - } - }(); - - dim3 grids = FmhaKernel::GridSize(batch, nhead, max_seqlen_q, hdim_v); - return ck_tile::make_tuple(kargs, grids); -} - -// This is the args from caller to underneath API, different from the kernel -struct fmha_fwd_args -{ - const void* q_ptr; - const void* k_ptr; - const void* v_ptr; - const void* bias_ptr; - void* lse_ptr; - void* o_ptr; - const void* seqstart_q_ptr; - const void* seqstart_k_ptr; - const void* seqlen_k_ptr; - ck_tile::index_t batch; - ck_tile::index_t nhead; - ck_tile::index_t nhead_k; - ck_tile::index_t seqlen_q; - ck_tile::index_t seqlen_k; - ck_tile::index_t hdim_q; - ck_tile::index_t hdim_v; - ck_tile::index_t max_seqlen_q; - float scale; - float descale_qk; - float descale_sv; - bool i_perm; - bool o_perm; - ck_tile::index_t mask_y; - ck_tile::index_t mask_x; -}; -#endif - // runtime args, some will passed to karg, some will used to compute grids/blocks struct fmha_fwd_args { @@ -289,8 +118,9 @@ struct fmha_fwd_args ck_tile::index_t batch_stride_bias; ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_o; - ck_tile::index_t mask_y; - ck_tile::index_t mask_x; + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; float descale_qk; float descale_sv; }; @@ -327,8 +157,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_stride_bias, args.nhead_stride_lse, args.nhead_stride_o, - args.mask_y, - args.mask_x, + args.window_size_left, + args.window_size_right, + args.mask_type, args.descale_qk, args.descale_sv); } @@ -363,8 +194,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.batch_stride_bias, args.batch_stride_lse, args.batch_stride_o, - args.mask_y, - args.mask_x, + args.window_size_left, + args.window_size_right, + args.mask_type, 
args.descale_qk, args.descale_sv); } @@ -385,6 +217,7 @@ template ; static constexpr bool kHasBias = kHasBias_; static constexpr bool kStoreLse = kStoreLse_; diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index e415974480..686dd35d19 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -24,6 +24,16 @@ DTYPE_BITS = { "bf8" : 8 } +MASK_IMPL = { + "generic" : "ck_tile::GenericAttentionMask", + "simplified" : "ck_tile::SimplifiedGenericAttentionMask" +} + +MASK_SIMPLIFIED_MAP = { + "s_no" : "ck_tile::SimplifiedGenericAttentionMask", + "s_mask" : "ck_tile::SimplifiedGenericAttentionMask", +} + MASK_MAP = { "no" : "FmhaMasks::NoMask", "causal" : "FmhaMasks::CausalMask", @@ -46,12 +56,17 @@ PIPELINE_MAP = { "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", } +PIPELINE_ENUM_MAP = { + "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_fp8" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_FP8", + "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", +} + BOOL_MAP = { "t" : "true", "f" : "false" } -MASKS = ["no", "causal", "generic"] DIRECTIONS = ["fwd"] GEN_DIR = "" # in Cmake, have to generate files in same folder @@ -113,7 +128,8 @@ using fmha_kernel_{F_idx} = fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>; -using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, + {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #include @@ -149,17 +165,40 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < """ MASK_CHECK_MAP = { "no" : "t.mask_type == mask_enum::no_mask", - "causal" : "t.mask_type == mask_enum::causal_top_left 
|| t.mask_type == mask_enum::causal_bottom_right", + "causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", "generic" : "t.mask_type == mask_enum::window_generic", } +MASK_SIMPLIFIED_CHECK_MAP = { + "s_no" : "t.mask_type == mask_enum::no_mask", + "s_mask" : "t.mask_type != mask_enum::no_mask", +} + FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_mask}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; return fmha_fwd_(s, a); }} """ +def get_mask_map(mask : str): + if mask == "generic": + return MASK_MAP + elif mask == "simplified": + return MASK_SIMPLIFIED_MAP + else: + assert False + return None + +def get_mask_check_map(mask : str): + if mask == "generic": + return MASK_CHECK_MAP + elif mask == "simplified": + return MASK_SIMPLIFIED_CHECK_MAP + else: + assert False + return None + @dataclass class FmhaFwdApiTrait: pipeline_tag : str @@ -193,14 +232,19 @@ class FmhaFwdApiTrait: if self.spad == 't' : return 'true' # always support else : return 'true' elif self.pipeline_tag in ['qr', 'qr_fp8']: - if self.spad == 't' : return f'a.seqlen_q % {self.bm0} != 0' + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! 
(ugly) else : return f'a.seqlen_q % {self.bm0} == 0' else: assert False @property def skcheck(self) -> str: - if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0' - else : return f'a.seqlen_k % {self.bn0} == 0' + if self.pipeline_tag == 'qr_async': + if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag in ['qr', 'qr_fp8']: + if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_k % {self.bn0} == 0' + else: assert False @property def dcheck(self) -> str: @@ -209,7 +253,7 @@ class FmhaFwdApiTrait: if self.dpad == 't': return f'a.hdim_q % {vec} == 0' else : assert False elif self.pipeline_tag in ['qr', 'qr_fp8']: - if self.dpad == 't': return f'a.hdim_q % {self.bk0blen} != 0' + if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.hdim_q % {self.bk0blen} == 0' else: assert False @@ -220,7 +264,7 @@ class FmhaFwdApiTrait: if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' else : assert False elif self.pipeline_tag in ['qr', 'qr_fp8']: - if self.dvpad == 't': return f'a.hdim_v % {self.bk0blen} != 0' + if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! 
(ugly) else : return f'a.hdim_v % {self.bk0blen} == 0' else: assert False @@ -251,13 +295,17 @@ class FmhaFwdPipeline: n = f'{self.tag}_v{self.F_vlayout[0]}' if pn != '' : n += f'_{pn}' if self.F_bias == 't' : n += '_bias' - if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' if self.F_lse == 't' : n += '_lse' return n class FmhaFwdApiPool: - def __init__(self): + def __init__(self, mask_impl): self.pool = dict() + self.mask_impl = mask_impl def register_traits(self, trait : FmhaFwdApiTrait) -> None: # TODO: do we need to check duplication? @@ -278,8 +326,9 @@ class FmhaFwdApiPool: inners=str() for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_mask=MASK_MAP[trait.mask], - F_mask_check=MASK_CHECK_MAP[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, @@ -320,6 +369,7 @@ class FmhaFwdKernel: F_mode : str # value from MODE_MAP F_tile : FmhaFwdTileSize F_pipeline : FmhaFwdPipeline + mask_impl : str @property def template(self) -> str: @@ -347,8 +397,9 @@ class FmhaFwdKernel: F_dvpad = 
BOOL_MAP[self.F_pipeline.F_dvpad], F_bias = BOOL_MAP[self.F_pipeline.F_bias], F_lse = BOOL_MAP[self.F_pipeline.F_lse], - F_occupancy = self.F_tile.F_occupancy , - F_mask = MASK_MAP[self.F_pipeline.F_mask], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode = MODE_MAP[self.F_mode], F_pipeline = PIPELINE_MAP[self.F_pipeline.tag]) @@ -403,14 +454,17 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[ else: return None -def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: +def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? 
pipelines = [] if dtype in ['fp16', 'bf16']: - for mask, bias, lse in itertools.product(MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"]): if hdim == 256: # if True: pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, mask)) @@ -423,16 +477,19 @@ def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaF pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, mask)) pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, mask)) pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, mask)) + if receipt == 1: + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, mask)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: # no need lse kernels - for mask, bias in itertools.product(MASK_MAP.keys(), ["t", "f"]): + for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"]): pipelines.append(FmhaFwdPipeline('qr_fp8', 'col', 'f', 'f', 'f', 'f', bias, 'f', mask)) else: assert False return pipelines gen = list() - api_pool = FmhaFwdApiPool() + api_pool = FmhaFwdApiPool(mask_impl) for direction, dtype in itertools.product(DIRECTIONS, DTYPE_MAP.keys()): d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype) @@ -443,7 +500,7 @@ def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaF tile = d[hdim_str] hdim = int(hdim_str) for pipeline in get_pipelines(dtype, hdim): - k = FmhaFwdKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_mode=mode, F_tile=tile, F_pipeline=pipeline) + k = FmhaFwdKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_mode=mode, F_tile=tile, F_pipeline=pipeline, mask_impl=mask_impl) if kernel_filter != None: if not 
fnmatch.fnmatch(k.name, kernel_filter): continue @@ -458,24 +515,24 @@ def write_single_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: def write_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Optional[str], kernel_filter : Optional[str]) -> None: +def write_blobs(output_dir : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: if output_dir is None: output_dir = Path(__file__).parent else: output_dir = Path(output_dir) / GEN_DIR output_dir.mkdir(parents=True, exist_ok=True) - api_pool, kernels = get_blobs(kernel_filter) + api_pool, kernels = get_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: write_single_kernel(kernel, output_dir) write_api(api_pool, output_dir) # list all the files that will be generated -def list_blobs(output_file : Optional[str], kernel_filter : Optional[str]) -> None: +def list_blobs(output_file : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: assert output_file is not None file_path = Path(output_file) with file_path.open('a') as f: - _, kernels = get_blobs(kernel_filter) + _, kernels = get_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") @@ -504,8 +561,26 @@ if __name__ == "__main__": required=False, help="filter out kernels that need to generate, using fnmatch module" ) + + parser.add_argument( + "-m", + "--mask", + default="simplified", + required=False, + help="mask implementation, simplified/generic" + ) + + parser.add_argument( + "-r", + "--receipt", + default=0, + required=False, + help="codegen receipt. 
0: generate only 8xhdim coverage\n" + \ + " 1: generate more instance to cover all hdim" + ) + args = parser.parse_args() if args.list_blobs is not None: - list_blobs(args.list_blobs, args.filter) + list_blobs(args.list_blobs, args.filter, args.receipt, mask_impl=args.mask) else: - write_blobs(args.output_dir, args.filter) + write_blobs(args.output_dir, args.filter, args.receipt, mask_impl=args.mask) diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp index d652172ede..526ea5dd04 100644 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -9,11 +9,12 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha.hpp" +// keep this in sync with ck_tile::GenericAttentionMaskEnum enum class mask_enum { no_mask = 0, - causal_top_left, - causal_bottom_right, + mask_top_left, + mask_bottom_right, window_generic, }; @@ -21,18 +22,19 @@ struct mask_info { mask_enum type; ck_tile::index_t y, x; + ck_tile::index_t left, right; // FA style SWA left/right void serialize(std::ostream& os) const { if(type == mask_enum::no_mask) os << "n"; - else if(type == mask_enum::causal_top_left) - os << "tl"; - else if(type == mask_enum::causal_bottom_right) - os << "br"; + else if(type == mask_enum::mask_top_left) + os << "tl(" << left << ":" << right << ")"; + else if(type == mask_enum::mask_bottom_right) + os << "br(" << left << ":" << right << ")"; else { - os << "g(" << y << "/" << x << ")"; + os << "g(" << y << ":" << x << ")"; } } static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k) @@ -57,22 +59,30 @@ struct mask_info // TODO: some validation if(t == "t") { - auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + tmp.type = mask_enum::mask_top_left; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( v0, v1, y_total, x_total, true); - tmp.y = r.at(ck_tile::number<0>{}); - tmp.x = r.at(ck_tile::number<1>{}); + tmp.y = r.at(ck_tile::number<0>{}); + 
tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; } else if(t == "b") { - auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + tmp.type = mask_enum::mask_bottom_right; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( v0, v1, y_total, x_total, false); - tmp.y = r.at(ck_tile::number<0>{}); - tmp.x = r.at(ck_tile::number<1>{}); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; } else if(t == "g") { - tmp.y = v0; - tmp.x = v1; + tmp.y = v0; + tmp.x = v1; + tmp.left = v0; // TODO: don't use this? + tmp.right = v1; } else { @@ -84,15 +94,19 @@ struct mask_info { // should be 0, 1, 2 tmp.type = static_cast(atoi(str.c_str())); - if(tmp.type == mask_enum::causal_top_left) + if(tmp.type == mask_enum::mask_top_left) { - tmp.y = seqlen_q; - tmp.x = 1; + tmp.y = seqlen_q; + tmp.x = 1; + tmp.left = -1; + tmp.right = 0; } - else if(tmp.type == mask_enum::causal_bottom_right) + else if(tmp.type == mask_enum::mask_bottom_right) { - tmp.y = seqlen_q; - tmp.x = seqlen_k - seqlen_q + 1; + tmp.y = seqlen_q; + tmp.x = seqlen_k - seqlen_q + 1; + tmp.left = -1; + tmp.right = 0; } } return tmp; diff --git a/example/ck_tile/01_fmha/script/smoke_test.sh b/example/ck_tile/01_fmha/script/smoke_test.sh index 012ea42df6..6b7bf8fe41 100644 --- a/example/ck_tile/01_fmha/script/smoke_test.sh +++ b/example/ck_tile/01_fmha/script/smoke_test.sh @@ -23,7 +23,8 @@ $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 - $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm 
-operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=256 -s_k=512 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=g:128,32 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=120 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS done done diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 1e9acc6d7b..c567e63ddf 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -6,6 +6,7 @@ #include "ck_tile/ops/fmha/block/block_masking.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp" diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index c256e08e46..39447ca99e 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -7,6 +7,20 @@ namespace ck_tile { +enum struct GenericAttentionMaskEnum +{ + NO_MASK = 0, + + // below enum could be causal, or sliding window + MASK_FROM_TOP_LEFT = 1, + MASK_FROM_BOTTOM_RIGHT = 2, + + // this enum maybe not used by xformer/FA, since it's hard to + // specify left/right window for varlen case. put it here for + // debug purpose + MASK_GENERIC, +}; + // clang-format off /* generic Attention Mask Coordinate use x(horizontal axis), y(vertical axis) to describe mask. 
@@ -188,6 +202,129 @@ struct GenericAttentionMask index_t y_total, x_total; }; +// clang-format off +namespace impl { + template struct SimplifiedMaskName; + template<> struct SimplifiedMaskName { static constexpr const char * name = "nomask"; }; + template<> struct SimplifiedMaskName { static constexpr const char * name = "mask"; }; +} +// clang-format on + +// this version only have 2 variation: masking and non-masking +// This is more friendly to codegen (e.g. need generate less kernel) +// ... with the trade-off that may have more instruction in causal mode +template +struct SimplifiedGenericAttentionMask +{ + static constexpr bool IsMasking = IsMasking_; // false will disable masking + + static constexpr const char* name = impl::SimplifiedMaskName::name; + + CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_) + : SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_) + { + } + + CK_TILE_HOST_DEVICE + SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_) + : y(y_), x(x_), y_total(y_total_), x_total(x_total_) + { + } + template + CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord) + : y(mask_coord.at(number<0>{})), + x(mask_coord.at(number<1>{})), + y_total(mask_coord.at(number<2>{})), + x_total(mask_coord.at(number<3>{})) + { + } + + // to get the loop length along X axis, return index:[start, end), end-start=length + // use this if need loop over X axis tile by tile (like k-seqlen loopover) + // TODO: x_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongX(index_t i_y, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, x_total); + } + else + { + // get the tile start/end range assum we loop over along X tile by tile + index_t x_start = [&]() { + index_t tmp = max(-y + i_y + 1, 0); + return (tmp / XTile) * XTile; // round to tile 
aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t x_end = [&]() { + index_t tmp = min(i_y + YTile - 1 + x, x_total); + return ((tmp + XTile - 1) / XTile) * XTile; + }(); + + return ck_tile::make_tuple(x_start, x_end); + } + } + + // per-pixel check if out-of-bound, if true, need mask a value(like -INF) + CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const + { + if constexpr(!IsMasking) + { + // the only case that need do following compare is under kPadSeqLenK + // ... for non-masking kernel. + return i_x >= x_total; + } + else + { + // no need to do min/max here, since i_x will never be < 0 or >= x_total + index_t x_start = -y + i_y + 1; // this could be negative, but it's fine + index_t x_end = i_y + x; // this could be larger than x_total, but it's fine + + return i_x < x_start || i_x >= x_end; + } + } + + // if current tile is at the edge, means need per-pixel mask check. + // otherwise no need to check per-pixel + // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX() + // can be used as a fast-path to decide if do per-pixel check or not + template + CK_TILE_HOST_DEVICE constexpr auto + IsEdgeTile(index_t i_y, index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + // the only case that need do following compare is under kPadSeqLenK + // ... for non-masking kernel. 
+ // return (i_x < x_total) && ((i_x + TileWidth) > x_total); + + // TODO: no need to check begin + return (i_x + TileWidth) > x_total; + } + else + { + // check top-right corner > x or left-borrom corner < x + index_t i_x_end = i_x + TileWidth; + index_t i_y_end = i_y + TileHeight; + // index_t x_end = min(i_y + x, x_total); + + bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad + bool bottom_left_edge = i_y_end > (i_x + y); + // bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now + + return top_right_edge || bottom_left_edge; + } + } + + private: + index_t y, x; + index_t y_total, x_total; +}; + // TODO: prefer use this function in host code // can convert from the FA style left/right to our generic coordinate // if left_size < 0 && right_size = 0, it is normal causal mask @@ -199,29 +336,32 @@ make_generic_attention_mask_coordinates_from_lr_window(index_t left_size, index_t x_total, bool is_top_left = true) { - index_t x = 0, y = 0; + // TODO: below should all use sgpr arithmetic + index_t left_size_tmp = is_top_left ? y_total - 1 : x_total - 1; + index_t right_size_tmp = is_top_left ? x_total - 1 : y_total - 1; - if(is_top_left) - { - if(left_size < 0) - left_size = y_total - 1; - if(right_size < 0) - right_size = x_total - 1; + left_size = left_size < 0 ? left_size_tmp : left_size; + right_size = right_size < 0 ? right_size_tmp : right_size; - x = 1 + right_size; - y = left_size + 1; - } - else - { - if(left_size < 0) - left_size = x_total - 1; - if(right_size < 0) - right_size = y_total - 1; + index_t x_tmp = is_top_left ? 0 : x_total - y_total; + index_t y_tmp = is_top_left ? 
0 : y_total - x_total; - x = x_total - y_total + 1 + right_size; - y = y_total - x_total + 1 + left_size; - } + index_t x = 1 + right_size + x_tmp; + index_t y = 1 + left_size + y_tmp; return ck_tile::make_tuple(y, x, y_total, x_total); } + +template +CK_TILE_HOST_DEVICE constexpr auto +make_generic_attention_mask_from_lr_window(index_t left_size, + index_t right_size, + index_t y_total, + index_t x_total, + bool is_top_left = true) +{ + auto r = make_generic_attention_mask_coordinates_from_lr_window( + left_size, right_size, y_total, x_total, is_top_left); + return MaskType{r.at(ck_tile::number<0>{}), r.at(ck_tile::number<1>{}), y_total, x_total}; +} } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 98866805a0..a5f7d95d42 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -138,7 +138,9 @@ struct FmhaFwdKernel struct FmhaFwdMaskKargs { - ck_tile::index_t mask_y, mask_x; + // ck_tile::index_t window_size_left, window_size_right; + ck_tile::index_t window_size_left, window_size_right; + ck_tile::GenericAttentionMaskEnum mask_type; }; struct FmhaFwdFP8Kargs @@ -217,8 +219,9 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, - ck_tile::index_t mask_y, - ck_tile::index_t mask_x, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, float descale_qk, float descale_sv) { @@ -262,8 +265,9 @@ struct FmhaFwdKernel } if constexpr(kHasMask) { - kargs.mask_y = mask_y; - kargs.mask_x = mask_x; + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); } if constexpr(kStoreLSE) { @@ -306,8 +310,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse, ck_tile::index_t 
nhead_stride_o, - ck_tile::index_t mask_y, - ck_tile::index_t mask_x, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, float descale_qk, float descale_sv) { @@ -349,8 +354,9 @@ struct FmhaFwdKernel } if constexpr(kHasMask) { - kargs.mask_y = mask_y; - kargs.mask_x = mask_x; + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); } if constexpr(kStoreLSE) { @@ -639,7 +645,12 @@ struct FmhaFwdKernel FmhaMask mask = [&]() { if constexpr(kHasMask) - return FmhaMask{kargs.mask_y, kargs.mask_x, kargs.seqlen_q, kargs.seqlen_k}; + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); else return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp new file mode 100644 index 0000000000..ae5a88df21 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck_tile { + +// This class is used for codegen pattern matching +enum class BlockFmhaPipelineEnum +{ + QRKSVS = 0, + QRKSVS_ASYNC, + QRKSVS_FP8, + QSKSVS, +}; + +} // namespace ck_tile