From 70238cab87659f0f00d31a77561428cb99a2dce6 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Fri, 8 Aug 2025 16:23:04 +0000 Subject: [PATCH] Device implementation of explicit gemm for grouped conv bwd weight Based on batched gemm multiple D --- ...tched_gemm_multiple_d_wmma_cshuffle_v3.hpp | 756 ++++++++++++++++++ ...atched_gemm_multiple_d_xdl_cshuffle_v3.hpp | 5 + ...vice_grouped_conv_bwd_weight_explicit.hpp} | 10 +- .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 8 +- ..._bwd_wei_exp_device_operation_instance.hpp | 24 +- ..._gemm_wmma_universal_km_kn_mn_instance.hpp | 111 +++ .../grouped_convolution_backward_weight.hpp | 155 +++- ...volution_backward_weight_explicit_wmma.inc | 459 +++++++++++ ...nvolution_backward_weight_explicit_xdl.inc | 72 +- .../grouped_convnd_bwd_weight/CMakeLists.txt | 59 +- ...16_bf16_bf16_exp_comp_default_instance.cpp | 67 ++ ...bf16_bf16_exp_comp_mnkpadding_instance.cpp | 67 ++ ..._bf16_bf16_exp_mem_v1_default_instance.cpp | 67 ++ ...16_bf16_exp_mem_v1_mnkpadding_instance.cpp | 69 ++ ..._bf16_bf16_exp_mem_v2_default_instance.cpp | 67 ++ ...16_bf16_exp_mem_v2_mnkpadding_instance.cpp | 69 ++ ...wmma_bf16_bf16_bf16_exp_odd_m_instance.cpp | 71 ++ ...mma_bf16_bf16_bf16_exp_odd_mn_instance.cpp | 71 ++ ...wmma_bf16_bf16_bf16_exp_odd_n_instance.cpp | 69 ++ ..._f16_f16_f16_exp_comp_default_instance.cpp | 67 ++ ...6_f16_f16_exp_comp_mnkpadding_instance.cpp | 67 ++ ...16_f16_f16_exp_mem_v1_default_instance.cpp | 67 ++ ...f16_f16_exp_mem_v1_mnkpadding_instance.cpp | 69 ++ ...16_f16_f16_exp_mem_v2_default_instance.cpp | 67 ++ ...f16_f16_exp_mem_v2_mnkpadding_instance.cpp | 69 ++ ...ht_wmma_f16_f16_f16_exp_odd_m_instance.cpp | 71 ++ ...t_wmma_f16_f16_f16_exp_odd_mn_instance.cpp | 71 ++ ...ht_wmma_f16_f16_f16_exp_odd_n_instance.cpp | 69 ++ ...6_bf16_bf16_exp_comp_default_instance.cpp} | 4 +- ...f16_bf16_exp_comp_mnkpadding_instance.cpp} | 4 +- ...bf16_bf16_exp_mem_v1_default_instance.cpp} | 4 +- ...6_bf16_exp_mem_v1_mnkpadding_instance.cpp} | 4 +- ...bf16_bf16_exp_mem_v2_default_instance.cpp} | 4 +- ...6_bf16_exp_mem_v2_mnkpadding_instance.cpp} | 4 +- ...xdl_bf16_bf16_bf16_exp_odd_m_instance.cpp} | 4 +- ...dl_bf16_bf16_bf16_exp_odd_mn_instance.cpp} | 4 +- ...xdl_bf16_bf16_bf16_exp_odd_n_instance.cpp} | 4 +- ...f16_f16_f16_exp_comp_default_instance.cpp} | 4 +- ..._f16_f16_exp_comp_mnkpadding_instance.cpp} | 4 +- ...6_f16_f16_exp_mem_v1_default_instance.cpp} | 4 +- ...16_f16_exp_mem_v1_mnkpadding_instance.cpp} | 4 +- ...6_f16_f16_exp_mem_v2_default_instance.cpp} | 4 +- ...16_f16_exp_mem_v2_mnkpadding_instance.cpp} | 4 +- ...ht_xdl_f16_f16_f16_exp_odd_m_instance.cpp} | 4 +- ...t_xdl_f16_f16_f16_exp_odd_mn_instance.cpp} | 4 +- ...ht_xdl_f16_f16_f16_exp_odd_n_instance.cpp} | 4 +- test/grouped_convnd_bwd_weight/CMakeLists.txt | 4 +- 47 files changed, 2820 insertions(+), 149 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp rename include/ck/tensor_operation/gpu/device/impl/{device_grouped_conv_bwd_weight_explicit_xdl.hpp => device_grouped_conv_bwd_weight_explicit.hpp} (98%) create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_explicit_wmma.inc create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_m_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_n_instance.cpp rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/{device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instance.cpp => device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/{device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp => device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/{device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instance.cpp => device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/{device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp => device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp} (92%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/{device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instance.cpp => device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/{device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp => device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp} (92%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/{device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instance.cpp => device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instance.cpp} (94%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/{device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instance.cpp => device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instance.cpp} (94%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/{device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instance.cpp => device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/{device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instance.cpp => device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_default_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/{device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instance.cpp => device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/{device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instance.cpp => device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/{device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp => device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/{device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instance.cpp => device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/{device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp => device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/{device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instance.cpp => device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instance.cpp} (94%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/{device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instance.cpp => device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instance.cpp} (94%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/{device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instance.cpp => device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instance.cpp} (94%) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..f9e3eb0eb3 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp @@ -0,0 +1,756 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_batched_gemm_multi_d_wmma_cshuffle_v3( + typename GridwiseGemm::Argument karg, // This works for now but it actually receives a + // DeviceBatchedGemm_Wmma_CShuffleV3::Argument + // argument through implicit conversion to base class! + const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using e_data_type = remove_cvref_t>; + if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + // The normal approach to batching would be to increase the grid size by just stretching out + // the grid Z dimension (which is the outermost dimension), but this depends on lower level + // functions not directly using the Z dimension for other calculations. As it turns out, k + // batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now + // we will use the grid Y dimension for batching. This may be a bit fragile. + const index_t g_idx = amd_wave_read_first_lane(blockIdx.y); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + const long_index_t c_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + static_for<0, GridwiseGemm::NumATensor, 1>{}( + [&](auto i) { splitk_batch_offset.a_k_split_offset[i] += a_batch_offset; }); + + static_for<0, GridwiseGemm::NumBTensor, 1>{}( + [&](auto i) { splitk_batch_offset.b_k_split_offset[i] += b_batch_offset; }); + + splitk_batch_offset.c_reduce_offset += c_batch_offset; + + // populate pointer, desc for Ds + static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { + // D pointer + karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i]; + }); + + GridwiseGemm::template Run( + p_shared, splitk_batch_offset, karg); +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; + ignore = compute_ptr_offset_of_batch; +#endif +} + +template +struct DeviceBatchedGemmMultiD_Wmma_CShuffleV3 + : public DeviceBatchedGemmV2MultiD +{ + using CDEShuffleBlockTransferScalarPerVectors_ = CDEShuffleBlockTransferScalarPerVectors; + using CDataType_ = EDataType; + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + ELayout, + Tuple, + Tuple, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, + false>; + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch() = default; + ComputePtrOffsetOfStridedBatch( + index_t BatchStrideA, + index_t BatchStrideB, + std::array BatchStrideDs, + index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideDs_(BatchStrideDs), + BatchStrideC_(BatchStrideC) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return static_cast(BatchStrideA_) * g_idx; + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return static_cast(BatchStrideB_) * g_idx; + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + std::array ds_offset_; + + static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { + ds_offset_[i] = static_cast(BatchStrideDs_[i]) * g_idx; + }); + + return ds_offset_; + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return static_cast(BatchStrideC_) * g_idx; + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + std::array BatchStrideDs_; + index_t BatchStrideC_; + }; + + struct Argument : public GridwiseGemm::Argument + { + index_t Batch; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; + + Argument() = default; + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + std::array p_ds_grid_, + EDataType* p_e_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideE_, + index_t BatchStrideA_, + index_t BatchStrideB_, + const std::array& BatchStrideDs_, + index_t BatchStrideE_, + index_t Batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CDEElementwiseOperation cde_element_op_, + index_t KBatch_) + : GridwiseGemm::Argument{std::array{p_a_grid_}, + std::array{p_b_grid_}, + p_ds_grid_, + p_e_grid_, + M_, + N_, + K_, + std::array{StrideA_}, + std::array{StrideB_}, + StrideDs_, + StrideE_, + KBatch_, + a_element_op_, + b_element_op_, + cde_element_op_, + false}, + Batch{Batch_}, + compute_ptr_offset_of_batch{ + BatchStrideA_, BatchStrideB_, BatchStrideDs_, BatchStrideE_} + { + } + template + void SetEPointer(void* ptr) + { + this->p_e_grid = static_cast(ptr); + } + }; + + struct ActiveWorkgroupsPerCU + { + ActiveWorkgroupsPerCU() + { + constexpr int dynamic_smem_size = 0; + int max_occupancy = 0; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; + } + else + { + return 1; + } + }(); + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // TODO + } + else + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_batched_gemm_multi_d_wmma_cshuffle_v3< + GridwiseGemm, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + + max_occupancy_ = std::max(1, max_occupancy); + } + int max_occupancy_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + gdy *= arg.Batch; + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0); + + // Packed sizes are 1 for all implemented data types but we include it anyway + // for future compatibility. + std::array size_as_buffers; + size_as_buffers[0] = arg_.Batch * + a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() * + sizeof(ADataType) / GridwiseGemm::APackedSize; + + std::array size_bs_buffers; + size_bs_buffers[0] = arg_.Batch * + b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() * + sizeof(BDataType) / GridwiseGemm::BPackedSize; + + const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( + arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); + + std::array size_ds_buffers; + static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + size_ds_buffers[i] = + ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); + }); + ck::utility::RotatingMemWrapperMultiABD, + Tuple, + DsDataType> + rotating_mem(arg_, + stream_config.rotating_count, + size_as_buffers, + size_bs_buffers, + size_ds_buffers); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString( + hipMemsetAsync(arg_.p_e_grid, + 0, + arg.Batch * arg_.M * arg_.N * sizeof(EDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_, + arg_.compute_ptr_offset_of_batch); + } + else + { + const auto clear_workspace = [&]() { + if(arg.KBatch > 1) + hipGetErrorString( + hipMemsetAsync(arg.p_e_grid, + 0, + arg.Batch * arg.M * arg.N * sizeof(EDataType), + stream_config.stream_id_)); + }; + + ave_time = + launch_and_time_kernel_with_preprocess(stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg, + arg.compute_ptr_offset_of_batch); + } + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; + } + else + { + return 1; + } + }(); + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_multi_d_wmma_cshuffle_v3< + GridwiseGemm, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_multi_d_wmma_cshuffle_v3< + GridwiseGemm, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + else + { + // TODO: Implement + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_multi_d_wmma_cshuffle_v3< + GridwiseGemm, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_multi_d_wmma_cshuffle_v3< + GridwiseGemm, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + + if constexpr(std::is_same_v || + std::is_same_v) + { + if(arg.KBatch > 1 && ck::is_gfx11_supported()) + { + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + index_t KBatch = 1) + { + return Argument{static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_e), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + BatchStrideA, + BatchStrideB, + BatchStrideDs, + BatchStrideE, + Batch, + a_element_op, + b_element_op, + cde_element_op, + KBatch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + const std::array& StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + index_t KBatch = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_e), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + BatchStrideA, + BatchStrideB, + BatchStrideDs, + BatchStrideE, + Batch, + a_element_op, + b_element_op, + cde_element_op, + KBatch); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceBatchedGemmMultipleD_Wmma_CShuffleV3" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(ELayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< + void SetEPointer(void* ptr) + { + this->p_c_grid = static_cast(ptr); + } }; using Argument = ArgumentBase; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp similarity index 98% rename from include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp rename to include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp index be94da1e50..4872679754 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp @@ -32,7 +32,7 @@ template -struct DeviceGroupedConvBwdWeight_Explicit_Xdl +struct DeviceGroupedConvBwdWeight_Explicit : public DeviceGroupedConvBwdWeight(arg.p_workspace_); + explicit_gemm_args_with_workspace.template SetEPointer( + arg.p_workspace_); float avg_time = explicit_gemm_op.Run(explicit_gemm_args_with_workspace, stream_config); const index_t grid_size = @@ -494,7 +494,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl auto str = std::stringstream(); // clang-format off - str << "DeviceGroupedConvBwdWeight_Explicit_Xdl" + str << "DeviceGroupedConvBwdWeight_Explicit" << "<" << DeviceGemmV3Op{}.GetTypeString() << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index d7d6652933..b24e7bc6ae 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -332,6 +332,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 struct Problem { + __host__ Problem() = default; __host__ Problem(index_t M_, index_t N_, index_t K_, @@ -406,6 +407,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 // Argument struct Argument : public tensor_operation::device::BaseArgument, public Problem { + __host__ Argument() = default; __host__ Argument(std::array p_as_grid_, std::array p_bs_grid_, std::array p_ds_grid_, @@ -472,9 +474,9 @@ struct GridwiseGemm_wmma_cshuffle_v3 DsGridPointer p_ds_grid; EDataType* p_e_grid; - const AElementwiseOperation a_element_op; - const BElementwiseOperation b_element_op; - const CDEElementwiseOperation cde_element_op; + AElementwiseOperation a_element_op; + BElementwiseOperation b_element_op; + CDEElementwiseOperation cde_element_op; // TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd bool is_reduce; diff --git a/library/include/ck/library/tensor_operation_instance/add_grouped_conv_bwd_wei_exp_device_operation_instance.hpp b/library/include/ck/library/tensor_operation_instance/add_grouped_conv_bwd_wei_exp_device_operation_instance.hpp index 8e2ee30430..8d0878d1a6 100644 --- a/library/include/ck/library/tensor_operation_instance/add_grouped_conv_bwd_wei_exp_device_operation_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/add_grouped_conv_bwd_wei_exp_device_operation_instance.hpp @@ -7,7 +7,7 @@ #include #include "ck/utility/functional2.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp" namespace ck { namespace tensor_operation { @@ -32,17 +32,17 @@ void add_explicit_gemm_device_operation_instances( ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { using DeviceGemmOp = std::tuple_element_t; - using NewOpInstance = DeviceGroupedConvBwdWeight_Explicit_Xdl; + using NewOpInstance = DeviceGroupedConvBwdWeight_Explicit; static_assert(std::is_base_of_v, "wrong! NewOpInstance should be derived from BaseOp"); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp new file mode 100644 index 0000000000..3907784b52 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_grouped_conv_bwd_wei_exp_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using namespace ck::tensor_layout::convolution; + +using BF16 = bhalf_t; +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMPadding = GemmSpecialization::MPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMKPadding = GemmSpecialization::MKPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_km_kn_mn_comp_instances = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm| + //#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline| + //#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 16, 16, 8, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_wmma_universal_km_kn_mn_mem_instances = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm| + //#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline| + //#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Latency friendly + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 8, 1, 8>, S<2,2,2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + // Memory friendly + // clang-format on + >; + +template +using device_gemm_wmma_universal_km_kn_mn_irregular_odd_m_instances = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm| + //#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline| + //#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Latency friendly + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 8, 1, 8>, S<2,2,2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_wmma_universal_km_kn_mn_odd_n_instances = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm| + //#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline| + //#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Latency friendly + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 8, 1, 8>, S<1,1,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm| + //#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline| + //#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Latency friendly + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 8, 1, 8>, S<1,1,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index 5656555819..96451c36b5 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -21,6 +21,7 @@ #endif #ifdef CK_USE_WMMA #include "grouped_convolution_backward_weight_wmma.inc" +#include "grouped_convolution_backward_weight_explicit_wmma.inc" #endif namespace ck { namespace tensor_operation { @@ -395,21 +396,24 @@ struct DeviceOperationInstanceFactory>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_default_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_mnkpadding_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_default_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_mnkpadding_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_m_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_n_instances( + std::vector>>& instances); +#endif + +// 3D +#ifdef CK_ENABLE_BF16 + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_default_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_mnkpadding_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_default_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_mnkpadding_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_m_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_n_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_explicit_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_explicit_xdl.inc index 8958e4c1ee..e70ff61f5b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_explicit_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_explicit_xdl.inc @@ -10,7 +10,7 @@ namespace instance { // 2D #ifdef CK_ENABLE_BF16 -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instances( std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_comp_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_comp_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp new file mode 100644 index 0000000000..8ad67dff8f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_comp_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_comp_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instance.cpp new file mode 100644 index 0000000000..1fe1be2386 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 0000000000..3c09f63a37 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>( + instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>( + instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instance.cpp new file mode 100644 index 0000000000..9efac544fb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 0000000000..c8717ef0b9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>( + instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>( + instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instance.cpp new file mode 100644 index 0000000000..2a2e0b6929 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instance.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_irregular_odd_m_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_irregular_odd_m_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instance.cpp new file mode 100644 index 0000000000..c07d91f607 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instance.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instance.cpp new file mode 100644 index 0000000000..792466cac8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instance.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_odd_n_instances>( + instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_odd_n_instances>( + instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instance.cpp new file mode 100644 index 0000000000..dff2c137d5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_comp_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_comp_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instance.cpp new file mode 100644 index 0000000000..265c14f46b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_comp_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_comp_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_default_instance.cpp new file mode 100644 index 0000000000..110430452f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_default_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_default_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_default_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 0000000000..cbb1171e65 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>( + instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>( + instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_default_instance.cpp new file mode 100644 index 0000000000..320c74e5f6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_default_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_default_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_default_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 0000000000..f8e4f454c6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>( + instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>( + instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_m_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_m_instance.cpp new file mode 100644 index 0000000000..5114de20c0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_m_instance.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_m_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_irregular_odd_m_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_m_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_irregular_odd_m_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instance.cpp new file mode 100644 index 0000000000..d719e15581 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instance.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_n_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_n_instance.cpp new file mode 100644 index 0000000000..03d04fb6c4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_n_instance.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_n_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_odd_n_instances>( + instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_n_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_odd_n_instances>( + instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instance.cpp similarity index 93% rename from library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instance.cpp index 088f4b0ef7..cacf7ec5c7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_default_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instances( std::vector