diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp index 11e2add132..a18f108e47 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp @@ -60,8 +60,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const long_index_t c_batch_offset = amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); + using EpilogueType = + typename std::conditional<GridwiseGemm::UseDirectStore, + typename GridwiseGemm::EpilogueDirectStore, + typename GridwiseGemm::EpilogueCShuffle>::type; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>(); __shared__ char p_shared[LDS_size]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -84,7 +90,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; }); - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = EpilogueType{}; GridwiseGemm::template Run( p_as_grid_shift, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp index ee1ddc494d..b88f071a96 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -46,8 +46,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) std::is_same_v))) { #endif - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); + using EpilogueType = + typename std::conditional<GridwiseGemm::UseDirectStore, + typename GridwiseGemm::EpilogueDirectStore, + typename GridwiseGemm::EpilogueCShuffle>::type; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>(); // The normal approach to batching would be to increase the grid size by just stretching out // the grid Z dimension (which is the outermost dimension), but this depends on lower level // functions not directly using the Z dimension for other calculations.
As it turns out, k @@ -86,7 +92,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; }); - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = EpilogueType{}; GridwiseGemm::template Run( p_as_grid_shift, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp index c64a1d504d..e8e3b69cb5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp @@ -188,7 +188,10 @@ struct DeviceGemmBiasAddReduce_Wmma_CShuffleV3 ComputeTypeA, ComputeTypeB, PermuteA, - PermuteB>; + PermuteB, + false, // IsBPreShuffled + false, // ForceThreadTileTransfer + true>; // IsFusedKernel using ReduceTrait = ReduceTrait_; + PermuteB, + false, + false, + true>; // Welford 2nd part kernel template diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp index b64b72f4d4..317c4073df 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp @@ -187,7 +187,10 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera ComputeTypeA, ComputeTypeB, PermuteA, - PermuteB>; + PermuteB, + false, + false, + true>; using ReduceTrait = ReduceTrait_))) { #endif - __shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>()]; + using EpilogueType = + typename std::conditional::type; - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; - GridwiseGemm::template Run::value, Number<0>, @@ -289,9 +303,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 NPerBlock / ClusterLengthNPerBlock>{}; template - static auto - MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) - + static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< @@ -307,21 +319,11 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); - const auto M = in_gemmm_gemmk_desc.GetLength(I0); - const auto K = in_gemmm_gemmk_desc.GetLength(I1); - - const auto AK0 = K / AK1; - - return transform_tensor_descriptor(in_gemmm_gemmk_desc, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + return in_gemmm_gemmk_desc; } template - static auto - MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< @@ -337,16 +339,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 const auto wei_gemmn_gemmk_desc = 
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); - const auto N = wei_gemmn_gemmk_desc.GetLength(I0); - const auto K = wei_gemmn_gemmk_desc.GetLength(I1); - - const auto BK0 = K / BK1; - - return transform_tensor_descriptor(wei_gemmn_gemmk_desc, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + return wei_gemmn_gemmk_desc; } template @@ -364,15 +357,21 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 const auto out_gemmmraw_gemmnraw_desc = conv_to_gemm_transformer.template MakeCDescriptor_M_N(); + // Force MN padding on the output tensor. This allows us to use the Gemm Default or K-only + // padding and removes some instructions in the hot loop (the same approach is used for gemm universal). if constexpr(CTranspose) { - constexpr auto matrix_padder_trans = - MatrixPadder{NPerBlock, MPerBlock, KPerBlock}; - return matrix_padder_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + constexpr auto matrix_padder_MN_padding_trans = + MatrixPadder{ + NPerBlock, MPerBlock, KPerBlock}; + return matrix_padder_MN_padding_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); } else { - return matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + constexpr auto matrix_padder_MN_padding = + MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + return matrix_padder_MN_padding.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); } } @@ -452,10 +451,10 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 BlkGemmPipelineVer, AComputeDataType, BComputeDataType, - false, // PermuteA - false, // PermuteB - false, // IsBPreShuffled - true>; // ForceThreadTileTransfer + false, // PermuteA + false, // PermuteB + false, // IsBPreShuffled + UseThreadTileTransfer>; // ForceThreadTileTransfer // TODO: the template param DoElementwiseBeforeCShuffle was previously available!
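[Editor's note] The forced-MN-padding change above is easier to see in isolation. Below is a minimal sketch (not part of this PR; the helper name and the tile sizes are invented) of why rounding M and N up to whole block tiles lets the main loop keep Default or K-only padding:

#include <cstdint>

// Hypothetical helper (not in CK): round a raw GEMM dimension up to a whole
// number of block tiles, which is what PadCDescriptor_M_N does to the output
// descriptor via a right-pad transform.
constexpr std::int64_t PadToMultiple(std::int64_t size, std::int64_t tile)
{
    return ((size + tile - 1) / tile) * tile;
}

// E.g. with MPerBlock = 128 and NPerBlock = 64, a 1000 x 100 output is described
// as 1024 x 128, so every workgroup writes a full tile and the hot loop only has
// to handle K-padding (if any).
static_assert(PadToMultiple(1000, 128) == 1024, "M rounded up to full tiles");
static_assert(PadToMultiple(100, 64) == 128, "N rounded up to full tiles");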
@@ -529,7 +528,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 false, // PermuteB false, // PermuteA false, // IsBPreShuffled - true>; // ForceThreadTileTransfer + true>; // ForceThreadTileTransfer (always force it because of limitations in the transfer) using GridwiseGemmCTranspose = std::conditional_t; @@ -626,10 +625,10 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 I1>; // desc for blockwise copy - using AGridDesc_AK0_M_AK1 = remove_cvref_t( - dummy_conv_to_gemm_transformer))>; - using BGridDesc_BK0_N_BK1 = remove_cvref_t( - dummy_conv_to_gemm_transformer))>; + using AGridDesc_M_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using BGridDesc_N_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; // Argument struct Argument : public BaseArgument @@ -695,10 +694,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 ds_grid_desc_m_n_{}, e_grid_desc_m_n_{ DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, - a_grid_desc_ak0_m_ak1_{ - MakeAGridDescriptor_AK0_M_AK1(conv_to_gemm_transformer_)}, - b_grid_desc_bk0_n_bk1_{ - MakeBGridDescriptor_BK0_N_BK1(conv_to_gemm_transformer_)}, + a_grid_desc_m_k_{MakeAGridDescriptor_M_K(conv_to_gemm_transformer_)}, + b_grid_desc_n_k_{MakeBGridDescriptor_N_K(conv_to_gemm_transformer_)}, compute_ptr_offset_of_groups_{}, compute_ptr_offset_of_n_{}, a_element_op_{a_element_op}, @@ -798,8 +795,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 } { - const index_t GemmM = a_grid_desc_ak0_m_ak1_.GetLength(I1); - const index_t GemmN = b_grid_desc_bk0_n_bk1_.GetLength(I1); + const index_t GemmM = a_grid_desc_m_k_.GetLength(I0); + const index_t GemmN = b_grid_desc_n_k_.GetLength(I0); const auto MBlock = CTranspose ? GridwiseGemmCTranspose::CalculateMBlock(GemmN) : GridwiseGemmCTranspose::CalculateMBlock(GemmM); const auto NBlock = CTranspose ? 
GridwiseGemmCTranspose::CalculateNBlock(GemmM) @@ -883,7 +880,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 is_same_v) { size_as_buffers[i] = - (a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() + + (a_grid_desc_m_k_.GetElementSpaceSize() + (num_group_ - NumGroupsToMerge) * (a_g_n_c_wis_strides_[0])) * sizeof(ADataType_single) / GridwiseGemm::APackedSize; } @@ -891,13 +888,13 @@ { if(CTranspose && a_g_n_c_wis_lengths_[I1] > 1) { - size_as_buffers[i] = (a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() + + size_as_buffers[i] = (a_grid_desc_m_k_.GetElementSpaceSize() + (eff_num_group - 1) * (a_g_n_c_wis_strides_[0])) * sizeof(ADataType_single) / GridwiseGemm::APackedSize; } else { - size_as_buffers[i] = a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() * + size_as_buffers[i] = a_grid_desc_m_k_.GetElementSpaceSize() * eff_num_group * sizeof(ADataType_single) / GridwiseGemm::APackedSize; } @@ -914,7 +911,7 @@ static_for<0, NumBTensor, 1>{}([&](auto i) { using BDataType_single = remove_cvref_t>; - size_bs_buffers[i] = b_grid_desc_bk0_n_bk1_.GetElementSpaceSize() * eff_num_group * + size_bs_buffers[i] = b_grid_desc_n_k_.GetElementSpaceSize() * eff_num_group * sizeof(BDataType_single) / GridwiseGemm::BPackedSize; }); @@ -961,8 +958,8 @@ void Print() const { - std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl; - std::cout << "B[BK0, N, BK1]: " << b_grid_desc_bk0_n_bk1_ << std::endl; + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; static_for<0, NumDTensor, 1>{}( [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; @@ -998,8 +995,8 @@ DsGridDesc_M_N ds_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_; - AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; - BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_; EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; @@ -1048,10 +1045,9 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 constexpr index_t minimum_occupancy = BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ?
1 : 2; - const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1); - const index_t GemmK = - arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + const index_t GemmM = arg.a_grid_desc_m_k_.GetLength(I0); + const index_t GemmN = arg.b_grid_desc_n_k_.GetLength(I0); + const index_t GemmK = arg.a_grid_desc_m_k_.GetLength(I1); const index_t num_workgroups_per_Conv_N = arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; @@ -1193,8 +1189,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 dim3(BlockSize), 0, gemm_arg_, - arg.b_grid_desc_bk0_n_bk1_, - arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_n_k_, + arg.a_grid_desc_m_k_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_groups_, @@ -1210,8 +1206,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 dim3(BlockSize), 0, gemm_arg, - arg.b_grid_desc_bk0_n_bk1_, - arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_n_k_, + arg.a_grid_desc_m_k_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_groups_, @@ -1291,8 +1287,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 dim3(BlockSize), 0, gemm_arg_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, + arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_groups_, @@ -1308,8 +1304,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 dim3(BlockSize), 0, gemm_arg, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, + arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_groups_, @@ -1327,8 +1323,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_wmma_cshuffle_v3< GridwiseGemmCTranspose, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_N_K, + DeviceOp::AGridDesc_M_K, DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, ComputePtrOffset, @@ -1342,8 +1338,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_wmma_cshuffle_v3< GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::AGridDesc_M_K, + DeviceOp::BGridDesc_N_K, DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, ComputePtrOffset, @@ -1985,10 +1981,9 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 } // check Gridwise GEMM - const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1); - const index_t GemmK = - arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + const index_t GemmM = arg.a_grid_desc_m_k_.GetLength(I0); + const index_t GemmN = arg.b_grid_desc_n_k_.GetLength(I0); + const index_t GemmK = arg.a_grid_desc_m_k_.GetLength(I1); if constexpr(CTranspose) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp 
index b7c0d89e0f..5ae9eaf8ac 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp @@ -66,8 +66,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CDEElementwiseOperation cde_element_op) { #if(defined(__gfx11__) || defined(__gfx12__)) - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); + using EpilogueType = typename std::conditional<GridwiseGemm::UseDirectStore, + typename GridwiseGemm::EpilogueDirectStore, + typename GridwiseGemm::EpilogueCShuffle>::type; + + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>(); __shared__ uint8_t p_shared[LDS_size]; const auto gemm_desc_ptr = @@ -150,7 +154,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) gemm_desc_ptr[group_id].StrideE, 1); - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = EpilogueType{}; constexpr TailNumber TailNum = TailNumber::Full; if(has_main_k_block_loop) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp index 714d567020..39024d39e4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp @@ -41,8 +41,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const index_t group_count) { #if(defined(__gfx11__) || defined(__gfx12__)) - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); + using EpilogueType = typename std::conditional<GridwiseGemm::UseDirectStore, + typename GridwiseGemm::EpilogueDirectStore, + typename GridwiseGemm::EpilogueCShuffle>::type; + + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>(); __shared__ char p_shared[LDS_size]; const index_t block_id = get_block_1d_id(); @@ -89,13 +93,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, tile_index[Number<0>{}]); - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = EpilogueType{}; GridwiseGemm::template Run(static_cast<void*>(p_shared), diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp index b8dd5905aa..dd12cdca8c 100644 --- a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp +++ b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp @@ -59,6 +59,8 @@ struct EpilogueCShuffleBase 1, CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma>>; + __device__ static constexpr bool IsLDSNeeded() { return true; } + // *Caution Here repeat is shuffle repeat __device__ static constexpr auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp new file mode 100644 index 0000000000..859225a831 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp @@ -0,0 +1,145 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +struct EpilogueDirectStore +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + + __device__ static constexpr bool IsLDSNeeded() { return false; } + + template + __device__ static void Run(CThreadBuf& c_thread_buf, + DsGridPointer, + EDataType* p_e_grid, + void*, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + CDEElementwiseOperation& cde_element_op, + const index_t& block_m_id, + const index_t& block_n_id) + { + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // C mapping in single thread. + constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + BlockwiseGemmPipe:: + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // C mapping in single block + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = + BlockwiseGemmPipe:: + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I1); + constexpr auto MSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I2); + constexpr auto NWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I4); + constexpr auto NThreadPerSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I5); + constexpr auto MAccVgprs = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I6); + + // origin + const auto c_thread_mtx_on_block = + BlockwiseGemmPipe::CalculateCThreadOriginDataIndex(I0, I0); + + const auto m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex( + make_multi_index(c_thread_mtx_on_block[I0])); + + const auto n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex( + make_multi_index(c_thread_mtx_on_block[I1])); + + // E grid descriptor + const auto c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + transform_tensor_descriptor( + e_grid_desc_mblock_mperblock_nblock_nperblock, + make_tuple(make_freeze_transform(block_m_id), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + 
Number{}, + Number{})), + make_freeze_transform(block_n_id), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<4, 5, 3>{})); + + auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + EDataType, + decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), + decltype(c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), + CDEElementwiseOperation, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 3, + NRepeat, // VectorSize + EGlobalMemoryDataOperation, + 1, + false>{c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_multi_index(m_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + n_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I2], + m_thread_data_on_grid_idx[I3]), + cde_element_op}; + + c_thread_copy.Run( + c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_tuple(I0, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + e_grid_buf); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp index cf471578ca..e47bb37a89 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp @@ -77,26 +77,79 @@ struct ABTransferWaveTiles static constexpr index_t KRepeat_ = KPerBlock / (KWaves_ * KPack); static constexpr index_t MNRepeat_ = MNPerBlock / (MNWaves_ * MNPerWmma); + template + __host__ __device__ static auto PadGridDescriptor(GridDescriptorBase& base_desc, + index_t sizeMN, + index_t MNPad, + index_t sizeK, + index_t KPad, + index_t, + index_t) + { + if constexpr(PadMN && PadK) + { + // pad both MN and K + return transform_tensor_descriptor( + base_desc, + make_tuple(make_right_pad_transform(sizeMN, MNPad - sizeMN), + make_right_pad_transform(sizeK, KPad - sizeK)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(PadMN && !PadK) + { + // pad MN, but not K + return transform_tensor_descriptor( + base_desc, + make_tuple(make_right_pad_transform(sizeMN, MNPad - sizeMN), + make_pass_through_transform(sizeK)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(!PadMN && PadK) + { + // pad K, but not MN + return transform_tensor_descriptor( + base_desc, + make_tuple(make_pass_through_transform(sizeMN), + make_right_pad_transform(sizeK, KPad - sizeK)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad MN or K + return base_desc; + } + } + template __host__ __device__ static auto MakeGridDescriptor(GridDescriptorBase& base_desc, index_t sizeMN, - index_t, + index_t MNPad, index_t sizeK, - index_t, + index_t KPad, index_t, index_t) { - // Notes: padding is currently not supported - static_assert(!PadMN && !PadK, "padding is currently not supported"); + // Notes: padding is currently not supported with transpose + static_assert(!((PadMN || PadK) && ABDoTranspose), + "padding is currently not supported with 
transpose"); + + const index_t MN_grid = !PadMN ? sizeMN : MNPad; + const index_t K_grid = !PadK ? sizeK : KPad; + + const auto base_desc_padded = + PadGridDescriptor(base_desc, sizeMN, MNPad, sizeK, KPad, 0, 0); // Divide the base descriptor MN_K into tiles const auto ab_grid_desc_mntiles_ktiles = transform_tensor_descriptor( - base_desc, + base_desc_padded, make_tuple( make_unmerge_transform(make_tuple( - math::integer_divide_ceil(sizeMN, Number{}), Number{})), - make_unmerge_transform(make_tuple(math::integer_divide_ceil(sizeK, Number{}), - Number{}))), + math::integer_divide_ceil(MN_grid, Number{}), Number{})), + make_unmerge_transform(make_tuple( + math::integer_divide_ceil(K_grid, Number{}), Number{}))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); @@ -112,9 +165,9 @@ struct ABTransferWaveTiles transform_tensor_descriptor( ab_grid_desc_mntiles_ktiles, make_tuple(make_pass_through_transform( - math::integer_divide_ceil(sizeMN, Number{})), + math::integer_divide_ceil(MN_grid, Number{})), make_pass_through_transform( - math::integer_divide_ceil(sizeK, Number{})), + math::integer_divide_ceil(K_grid, Number{})), make_pass_through_transform(Number{}), make_unmerge_transform( make_tuple(Number{}, Number{}))), @@ -127,8 +180,8 @@ struct ABTransferWaveTiles ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1, make_tuple( make_pass_through_transform( - math::integer_divide_ceil(sizeMN, Number{})), - make_pass_through_transform(math::integer_divide_ceil(sizeK, Number{})), + math::integer_divide_ceil(MN_grid, Number{})), + make_pass_through_transform(math::integer_divide_ceil(K_grid, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), make_freeze_transform(I0)), @@ -143,9 +196,9 @@ struct ABTransferWaveTiles transform_tensor_descriptor( ab_grid_desc_mntiles_ktiles, make_tuple(make_pass_through_transform( - math::integer_divide_ceil(sizeMN, Number{})), + math::integer_divide_ceil(MN_grid, Number{})), make_pass_through_transform( - math::integer_divide_ceil(sizeK, Number{})), + math::integer_divide_ceil(K_grid, Number{})), make_unmerge_transform( make_tuple(Number{}, Number{})), make_pass_through_transform(Number{})), @@ -157,8 +210,8 @@ struct ABTransferWaveTiles ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1, make_tuple( make_pass_through_transform( - math::integer_divide_ceil(sizeMN, Number{})), - make_pass_through_transform(math::integer_divide_ceil(sizeK, Number{})), + math::integer_divide_ceil(MN_grid, Number{})), + make_pass_through_transform(math::integer_divide_ceil(K_grid, Number{})), make_pass_through_transform(Number{}), make_freeze_transform(I0), make_pass_through_transform(Number{})), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp new file mode 100644 index 0000000000..bfe5b7bd08 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp @@ -0,0 +1,275 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/utility/amd_address_space.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp" +#include "ck/utility/math.hpp" + +namespace ck { + +template +struct ABTransferWaveTilesInterleave : ABTransferWaveTiles +{ + using Base = ABTransferWaveTiles; + + using Base::ABDoTranspose; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::MNKRow; + + using Base::GetBlockLaneIdx; + using Base::GetBlockStep; + using Base::GetGridLaneIdx; + using Base::GetWaveIdx; + using Base::PadGridDescriptor; + using typename Base::ThisThreadBlock; + + static constexpr auto I4 = Number<4>{}; + + static_assert(!ABDoTranspose, "wave tile interleaved transfer does not support transpose yet"); + + using Base::KRepeat_; + using Base::KWaves_; + using Base::MNRepeat_; + + static constexpr index_t MNWaves_Grid = MNWaves_Gemm; + static constexpr index_t KWaves_Grid = (BlockSize / WaveSize) / MNWaves_Gemm; + static constexpr index_t KRepeat_Grid = KPerBlock / (KWaves_Grid * KPack); + static constexpr index_t MNRepeat_Grid = MNPerBlock / (MNWaves_Grid * MNPerWmma); + + template + __host__ __device__ static auto MakeGridDescriptor(GridDescriptorBase& base_desc, + index_t sizeMN, + index_t MNPad, + index_t sizeK, + index_t KPad, + index_t, + index_t) + { + const auto base_desc_padded = Base::template PadGridDescriptor( + base_desc, sizeMN, MNPad, sizeK, KPad, 0, 0); + + const index_t MN_grid = !PadMN ? sizeMN : MNPad; + const index_t K_grid = !PadK ? sizeK : KPad; + + // Divide the base descriptor MN_K into tiles + const auto ab_grid_desc_mntiles_ktiles = transform_tensor_descriptor( + base_desc_padded, + make_tuple(make_unmerge_transform(make_tuple( + math::integer_divide_ceil(MN_grid, Number{}), + Number{})), + make_unmerge_transform(make_tuple( + math::integer_divide_ceil(K_grid, Number{}), Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); + + // The distinction is needed to get the same global indices for both layouts + // Divide each tile in 2 16x8 subtile + // MNTiles - KTiles - MNKRow - LaneLocal - VectorSize + // MNKRow = 0-1 + // LaneLocal = 0-15 + // VectorSize must be 8 + if constexpr(!ABDoTranspose) + { + const auto ab_grid_desc_mntiles_ktiles_mnrepeat = transform_tensor_descriptor( + ab_grid_desc_mntiles_ktiles, + make_tuple( + make_pass_through_transform( + math::integer_divide_ceil(MN_grid, Number{})), + make_pass_through_transform(math::integer_divide_ceil(K_grid, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<3, 2>{}, Sequence<4>{})); + + const auto ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1 = + transform_tensor_descriptor( + ab_grid_desc_mntiles_ktiles_mnrepeat, + make_tuple(make_pass_through_transform(math::integer_divide_ceil( + MN_grid, Number{})), + make_pass_through_transform( + math::integer_divide_ceil(K_grid, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}))), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + 
Sequence<4, 5>{})); + + // Freeze VectorSize to first element of the loading chunk (for convenience) + // Swap MNPerWmma and MNKRow for consistency with transpose descriptor + return transform_tensor_descriptor( + ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1, + make_tuple( + make_pass_through_transform( + math::integer_divide_ceil(MN_grid, Number{})), + make_pass_through_transform(math::integer_divide_ceil(K_grid, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_freeze_transform(I0)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<3>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<>{})); + } + } + + __device__ static constexpr auto GetBlockDescriptor() + { + // LDS memory layouts: + // lanes within tiles stored contiguously in chunks of 8 elements + // tiles are then stored first in K dimension + // MNTiles - KTiles - MNKRow - LaneLocal - VectorSize + const auto a_grid_desc_mraw_kraw = [&]() { + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + I1)); + }(); + + // Freeze VectorSize to first element of the chunk (for convenience) + return transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_freeze_transform(I0)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<>{})); + } + + template + __device__ static auto GetBlockTransfer(GridDescriptor& grid_descriptor, + BlockDescriptor& block_descriptor, + ABElementwiseOperation& ab_element_op, + const index_t block_mn_id, + const index_t) + { + // Note: GlobalBufferNum is currently not used but it will be needed + // once we add other pipelines. 
It is currently needed only for + // consistency with the thread tiles approach + static_assert(GlobalBufferNum == 1, "single global buffer is only supported"); + constexpr index_t NumABTensor = ABsDataType::Size(); + static_assert(NumABTensor == 1, "multiAB currently not supported"); + + using ABDataType = remove_cvref_t>; + + const auto wave_idx = GetWaveIdx(); + index_t wave_idK = wave_idx[I1]; + index_t wave_idMN = wave_idx[I0]; + + const auto grid_lane_id = Base::template GetGridLaneIdx(); + index_t lane_group_grid = grid_lane_id[I0]; + index_t lane_local_id_grid = grid_lane_id[I1]; + + const auto block_lane_id = GetBlockLaneIdx(); + index_t lane_group_block = block_lane_id[I0]; + index_t lane_local_id_block = block_lane_id[I1]; + + constexpr index_t MNRepeatRatio = MNRepeat_Grid / MNRepeat_; + return ThreadGroupTransferGlobal, + Sequence, + Sequence, + ABK1Value, + ABDoTranspose>( + grid_descriptor[I0], + block_descriptor, + make_multi_index(block_mn_id * MNWaves_Grid + wave_idMN / MNRepeatRatio, + wave_idK * KRepeat_Grid, + (wave_idMN % MNRepeatRatio) * MNRepeat_, + lane_group_grid, + lane_local_id_grid), + make_multi_index(wave_idMN / MNRepeatRatio, + wave_idK * KRepeat_, + (wave_idMN % MNRepeatRatio) * MNRepeat_, + lane_group_block, + lane_local_id_block), + ab_element_op); + } + + __device__ static constexpr auto GetBlockStep() + { + // Grid descriptor step (MoveSrcSliceWindow) + return make_multi_index(I0, KWaves_ * KRepeat_, I0, I0, I0); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index 9f7fd47083..b46afda8b7 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -177,7 +177,8 @@ template + bool ForceThreadTileTransfer = false, + bool IsFusedKernel = false> struct GridwiseGemm_wmma_cshuffle_v3 : GridwiseGemm_wmma_cshuffle_v3_base< ALayout, @@ -231,7 +232,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 PermuteA, PermuteB, IsBPreShuffled, - ForceThreadTileTransfer> + ForceThreadTileTransfer, + IsFusedKernel> { using Base = GridwiseGemm_wmma_cshuffle_v3_base< ALayout, @@ -285,7 +287,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 PermuteA, PermuteB, IsBPreShuffled, - ForceThreadTileTransfer>; + ForceThreadTileTransfer, + IsFusedKernel>; using Base::I0; using Base::I1; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 79549d6385..ec7710d066 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -15,6 +15,7 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" @@ -24,6 +25,7 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp" #include 
"ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp" #include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp" #include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp" #include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp" @@ -50,13 +52,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) std::is_same_v))) { #endif - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); + using EpilogueType = + typename std::conditional::type; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = EpilogueType{}; GridwiseGemm::template Run( p_shared, splitk_batch_offset, karg, epilogue_args); @@ -167,7 +175,8 @@ template // only needed for convolution (limitation) + bool ForceThreadTileTransfer = false, // only needed for convolution (limitation) + bool IsFusedKernel = false> struct GridwiseGemm_wmma_cshuffle_v3_base { @@ -182,6 +191,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base static constexpr index_t NumATensor = AsDataType::Size(); static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); using LDSTypeA = typename std::conditional<(NumATensor > 1), @@ -232,30 +242,44 @@ struct GridwiseGemm_wmma_cshuffle_v3_base return 1; }(); + static constexpr index_t WaveSize = + WmmaSelector::selected_wmma + .wave_size; + // Limitations of the current implementation: // - no multiAB - // - GemmSpecialization Default - // - pipeline v1 because v3 is buggy (fixed in batched gemm gemm implementation) - // AK1Value == 8 is not really a limitation but a requirement for the method so - // it will stay + // - GemmSpecialization Default with transpose #ifdef __gfx12__ static constexpr bool IsAWaveTransferApplicable = !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 && - GemmSpec == tensor_operation::device::GemmSpecialization::Default && + ((GemmSpec == tensor_operation::device::GemmSpecialization::Default && + !is_same_v) || + is_same_v) && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 && !IsBPreShuffled; static constexpr bool IsBWaveTransferApplicable = !ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 && - GemmSpec == tensor_operation::device::GemmSpecialization::Default && + ((GemmSpec == tensor_operation::device::GemmSpecialization::Default && + !is_same_v) || + is_same_v) && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8; + + static constexpr bool IsWaveTileInterleavedFitting = + (NPerBlock / NPerWmma / NRepeat) * (KPerBlock / KPack) >= (BlockSize / WaveSize); + + // We need to investigate if it makes sense to remove cshuffle for smaller types + // Currently we use direct store for NRepeat equal to 4 or 8. 
For 16 bit type we use at + // least buffer store 64 bit for 16 contiguous threads -> 128 bytes in total (full cache line) + static constexpr bool UseDirectStore = is_same_v && + sizeof(ComputeTypeB) == 2 && sizeof(EDataType) == 2 && + NumDTensor == 0 && (NRepeat == 4 || NRepeat == 8) && + !IsFusedKernel && IsWaveTileInterleavedFitting; #else static constexpr bool IsAWaveTransferApplicable = false; static constexpr bool IsBWaveTransferApplicable = false; + static constexpr bool UseDirectStore = false; #endif - static constexpr index_t WaveSize = - WmmaSelector::selected_wmma - .wave_size; static constexpr bool UseBlockPaddingA = ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4; using ATransfer = typename std::conditional< @@ -293,7 +317,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base static constexpr bool UseBlockPaddingB = BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4; - using BTransfer = typename std::conditional< IsBPreShuffled, ABTransferThreadTilesPreShuffle, typename std::conditional< IsBWaveTransferApplicable, - ABTransferWaveTiles, + typename std::conditional< + UseDirectStore, + ABTransferWaveTilesInterleave, + ABTransferWaveTiles>::type, ABTransferThreadTiles{}); } + template + __device__ static auto MakeAGridDescriptor_AK0_M_AK1(const GridDescBase& base_desc) + { + const auto M = base_desc.GetLength(I0); + const auto K = base_desc.GetLength(I1); + + const auto AK0 = K / AK1Value; + + constexpr bool padM = false; + constexpr bool padK = false; + return ATransfer::template MakeGridDescriptor(base_desc, M, M, K, K, 0, AK0); + } + __host__ __device__ static auto MakeBsGridDescriptor_BK0_N_BK1(const index_t K, const index_t KPad, @@ -516,6 +565,19 @@ struct GridwiseGemm_wmma_cshuffle_v3_base Number{}); } + template + __device__ static auto MakeBGridDescriptor_BK0_N_BK1(const GridDescBase& base_desc) + { + const auto N = base_desc.GetLength(I0); + const auto K = base_desc.GetLength(I1); + + const auto BK0 = K / BK1Value; + + constexpr bool padN = false; + constexpr bool padK = false; + return BTransfer::template MakeGridDescriptor(base_desc, N, N, K, K, 0, BK0); + } + __host__ __device__ static constexpr auto MakeAWmmaTileDescriptor() { constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); @@ -594,8 +656,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base #endif } - static constexpr index_t NumDTensor = DsDataType::Size(); - static constexpr auto MakeDsGridPointer() { return generate_tuple( @@ -679,6 +739,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_base ThisThreadBlock, BlockwiseGemmPipe>; + using EpilogueDirectStore = EpilogueDirectStore; + using EpilogueWelfordCShuffle = EpilogueWelfordCShuffle< DsDataType, EDataType, @@ -1000,18 +1068,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_base max_lds_align) : 0; - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = - EpilogueType:: - GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + if constexpr(EpilogueType::IsLDSNeeded()) + { + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + EpilogueType:: + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); - constexpr auto c_block_size = - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat - .GetElementSpaceSize(); + constexpr auto c_block_size = + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetElementSpaceSize(); - 
return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize + - b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize), - c_block_size * sizeof(CShuffleDataType)); + return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize + + b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize), + c_block_size * sizeof(CShuffleDataType)); + } + else + { + return a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize + + b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize; + } } template @@ -1148,7 +1224,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base num_k_block_main_loop, num_k_block_per_scale); - // shuffle C and write out + // Epilogue: + // - CShuffle / direct store + // - Multiple Ds + // - Fused operations epilogue_args.template Run( c_thread_buf, p_ds_grid, diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instance.hpp new file mode 100644 index 0000000000..2529c55e31 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instance.hpp @@ -0,0 +1,76 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddClamp = ck::tensor_operation::element_wise::AddClamp; +using Clamp = ck::tensor_operation::element_wise::Clamp; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmDefault = GemmSpecialization::Default; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline 
scheduler | Pipeline version | + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index d38aa66ece..08e2092c50 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -797,6 +797,8 @@ struct DeviceOperationInstanceFactory>>& 
instances); + +void add_device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 @@ -117,6 +131,20 @@ void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instanc PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); #endif // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 380c83fa92..4b8f1d1a16 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -125,6 +125,8 @@ set(GROUPED_CONV2D_FWD wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part4.cpp wmma/large_tensor/device_grouped_conv2d_fwd_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp wmma/large_tensor/device_grouped_conv2d_fwd_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp ) # Add generated files for sharded instantiations. include(ShardInstantiation) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..cbb4eae126 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,51 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + GemmMNKPadding, + BF16>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + GemmDefault, + BF16>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 0000000000..099804294d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,51 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + GemmMNKPadding, + F16>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + GemmDefault, + F16>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck
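[Editor's note] A minimal, self-contained sketch of the compile-time epilogue selection used throughout this PR (the type names and byte counts below are stand-ins; only the std::conditional selection pattern and the IsLDSNeeded() gating mirror the real kernels):

#include <cstddef>
#include <type_traits>

// Stand-ins for the two epilogues (hypothetical; the real ones live in
// epilogue_cshuffle_v3_wmma_base.hpp and epilogue_direct_store.hpp).
struct EpilogueCShuffleSketch
{
    static constexpr bool IsLDSNeeded() { return true; }
};

struct EpilogueDirectStoreSketch
{
    static constexpr bool IsLDSNeeded() { return false; }
};

template <bool UseDirectStore>
struct GridwiseGemmSketch
{
    // Same selection pattern as the kernel entry points in this PR.
    using EpilogueType = typename std::conditional<UseDirectStore,
                                                   EpilogueDirectStoreSketch,
                                                   EpilogueCShuffleSketch>::type;

    static constexpr std::size_t ab_lds_bytes = 16384; // A/B tiles (made-up size)
    static constexpr std::size_t c_lds_bytes  = 32768; // CShuffle tile (made-up size)

    // Mirrors the IsLDSNeeded() gating added to GetSharedMemoryNumberOfByte():
    // the direct-store epilogue writes accumulators straight to global memory,
    // so only the A/B tiles occupy LDS.
    template <typename Epilogue>
    static constexpr std::size_t GetSharedMemoryNumberOfByte()
    {
        return Epilogue::IsLDSNeeded()
                   ? (ab_lds_bytes > c_lds_bytes ? ab_lds_bytes : c_lds_bytes)
                   : ab_lds_bytes;
    }
};

// CShuffle epilogue: LDS must also hold the C staging tile.
static_assert(GridwiseGemmSketch<false>::GetSharedMemoryNumberOfByte<
                  GridwiseGemmSketch<false>::EpilogueType>() == 32768);
// Direct store: the C staging buffer is gone, halving the LDS footprint here.
static_assert(GridwiseGemmSketch<true>::GetSharedMemoryNumberOfByte<
                  GridwiseGemmSketch<true>::EpilogueType>() == 16384);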