From 12e0f0b1ba0e425411d58c5aa0df43ce93ca2f4a Mon Sep 17 00:00:00 2001
From: Wojciech Laskowski <77888887+wj-laskowski@users.noreply.github.com>
Date: Fri, 19 Dec 2025 05:55:50 +0100
Subject: [PATCH] Added large tensor support for grouped conv fwd wmma (#3437)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Padding is not supported when BDataType is pk_i4_t. Added a fix for the correct check and removed the padding instances.
* Fixed typos
* Updated the set of tests for FP16
* Updated the set of tests for FP16
* Fix typo
* Moved f16xi4 test under the correct data layout group
* example for gemm_universal_bf16
* Adding examples for gemm_wmma instances
* Added the missing parameters
* Fixed review comments and added executable to CMakeLists
* Fixing clang format
* Fixing build errors
* Fixed compilation failure.
* Modified some code as per gemm_universal_examples
* Fixed the gemm specialization error
* Fixed the build errors.
* Fix strides of a/b_thread_desc
  The descriptors are larger than needed (even though the compiler doesn't allocate registers for unused values).
* Load in M/NRepeat dims with thread copy's slice instead of a loop
* Clone BlockwiseGemmXdlops_pipeline_v1 for WMMA implementation
* Implement Intrawave and Interwave variants of pipeline v1
* Add instances for Interwave and Intrawave v1
* Add instances with ABlockLdsExtraM and BBlockLdsExtraN = 0
* Remove instances that are too slow (mostly because of register spilling)
* Add a workaround for fp8/bf8->f32 packed conversion issue
* Add instances for Interwave and Intrawave v1
* Enable profiling of mixed precision with f8 and int4 on WMMA
* Fix segfault in profiler when B is pk_i4_t
  b_device_buf's size in bytes is larger than b_k_n_permute's, so b_device_buf.ToDevice reads out of bounds.
* Remove instances that are too slow (mostly because of register spilling)
* Add missing add_device_gemm_wmma_universal_f8_f8_bf16 declarations
* Add test case for bf16_i4
* Add missing Regular tests
* Add test_gemm_universal_xdl/wmma_fp16 to REGRESSION_TESTS
  They take more than 30 seconds.
* Fix a bug where fp16_i4 validation passes only with PermuteB
  The permutation required by the conversion from pk_i4_t to half_t does not depend on PermuteB; they can be used independently.
* Use PermuteB with f16_i4 in most instances (as xdl)
  Some instances use PermuteB = false for checking correctness. See also the previous commit.
* Fix cache flushing for pk_i4
* Add mixed precision examples
* Disable all tests and instances with f8 on gfx11
  Even though f8_f16 and f16_f8 don't require f8 WMMA instructions, gfx11 still lacks hardware instructions for fast f8->f32 conversion.
* Add FP16 KM_NK and KM_KN test suites for XDL
  These tests were added to the common .inc for better testing of WMMA instances.
* Support multiple D in GridwiseGemm_wmma_cshuffle_v3 (the fused multiple-D epilogue is sketched below)
  DeviceGemm_Wmma_CShuffleV3 is changed for the new template parameters.
* Use ThreadGroupTensorSliceTransfer_v7r3
* Clone device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support
* Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for wmma
* Implement DeviceGemmMultipleD_Wmma_CShuffleV3
* Make gemm_add_add_wmma work with DeviceGemmMultipleD_Wmma_CShuffleV3
* Prepare gemm_add tests for adding wmma
* Add gemm_add_fastgelu instances and test
* Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with the old API
  ckProfiler uses DeviceGemmMultipleD (tests also call its functions); the wrapper allows using DeviceGemmMultipleDSplitK instances there.
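
A note on the "multiple D" commits above: the CDE elementwise epilogue fuses one or more extra source tensors (D0, D1, ...) into the GEMM output E after the matrix product. A minimal host-side sketch of the idea, assuming a simple add-add fusion; the functor and reference loop below are illustrative, not CK's actual device API:

    // Hedged sketch of a "multiple D" epilogue as fused by the gridwise GEMM.
    // CK's real functors live in ck::tensor_operation::element_wise and are
    // applied per output element inside the CShuffle epilogue.
    struct AddAdd
    {
        template <typename E, typename C, typename D>
        void operator()(E& e, const C& c, const D& d0, const D& d1) const
        {
            e = static_cast<E>(c + d0 + d1); // E = C + D0 + D1
        }
    };

    // Reference epilogue over an M x N output: E[m][n] = op(C[m][n], D0[m][n], D1[m][n]).
    template <typename Op>
    void RunCDEEpilogue(int M, int N, const float* c, const float* d0,
                        const float* d1, float* e, Op op)
    {
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                const int i = m * N + n;
                op(e[i], c[i], d0[i], d1[i]);
            }
    }
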
* removed unnecessary ck parts from compilation
* initial gemm_add_multiply instance implementations
* fixed profiler help message for gemm_add_multiply
* improved multiply_add profiler layout help
* fixed template arguments for test instances
* added test for gemm_add_multiply
* Support multiple D in GridwiseGemm_wmma_cshuffle_v3
  DeviceGemm_Wmma_CShuffleV3 is changed for the new template parameters.
* Use ThreadGroupTensorSliceTransfer_v7r3
* Clone device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support
* Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for wmma
* Implement DeviceGemmMultipleD_Wmma_CShuffleV3
* Make gemm_add_add_wmma work with DeviceGemmMultipleD_Wmma_CShuffleV3
* Prepare gemm_add tests for adding wmma
* Add gemm_add_fastgelu instances and test
* Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with the old API
  ckProfiler uses DeviceGemmMultipleD (tests also call its functions); the wrapper allows using DeviceGemmMultipleDSplitK instances there.
* switched to splitK interface
* log print added to splitk benchmarks
* revert main cmake comments
* newline change reverted
* added add_fastgelu instances
* revert unintended change in xdl add_fastgelu
* created gemm_add_add_fastgelu instances (the fastgelu epilogue math is sketched below)
* created fastgelu instances
* added tests for all splitk fastgelus
* Added tests.
* multiply_add instances created
* updates to add_multiply splitk instances
* splitk xdl test fixes
* added wmma multiply_multiply instances
* fixed ONLY_XDL_AND_WMMA_KERNELS tag
* Added gemm_add examples for wmma v1 and v3
* fixed / worked around i8 instances
* Modified the v3 code to add one fp16 bxdl instance.
* added bf16 xdl instance.
* adding gemm_add wmma_cshuffle and other support
  (cherry picked from commit ec447e7f564095ea969eddc39ec77b843aa52976)
  Co-authored-by: Cenxuan
* add instances into CMakeLists
  (cherry picked from commit 23bf2d2771c939ea3ca7f493433c55255bffd08e)
  Co-authored-by: Cenxuan
* This is work in progress; edited the template parameters in order to build
  (cherry picked from commit b4fde8a3314cb44659c4bbda35f1a0133c63dc41)
  Co-authored-by: Cenxuan
* temp work saved; changed the BDataType to f16 or bf16 since wmma currently does not support non-equal A and B data types
  (cherry picked from commit 22fbd68f1db458ab50780a394ee2544c7a1484d1)
  Co-authored-by: Cenxuan
* added datatype and used clang-format-12
  (cherry picked from commit ae4e853682ef1bb27784b2f965b4a66b3751ceec)
  Co-authored-by: Cenxuan
* Fixing build errors
* Added instances for v3
* Adding instances and executables
* Code update: template parameters modified.
* Renamed file.
* Added tests.
* resolved test errors.
* Fixing build errors
* Updated comments
* removed the changes as per the MR review comment.
* Updated tests.
* fp8 instances - not tested
* Restored the CMake file that was reverted by mistake during rebase.
* fixed wmma_op test
* Updated comments.
* Updated the template parameter description
* fixed rdna4 instances
* fixed backward compatibility on gfx11
* cleanups
* fix ckProfiler
* one more cmake fix
* added fp8 instances
* Updated tests to add BF16 instances as per review comment
* Added include file and cleaned up (as per review comment)
* Updated and optimized the example code for all types.
* Fixed clang format
* Resolve "Implement `device_gemm_bilinear` for RDNA4"
* test generalization to handle FP16 shuffle better
* added missing changes
* Added bf16 wmma instance for add_relu
* Added f16 wmma instance and corrected bf16 instance errors.
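
The fastgelu commits above fuse a tanh-based GELU approximation into the epilogue after the adds. A hedged sketch of the commonly used formula; the exact constants and functor shape in CK's element_wise implementation may differ:

    #include <cmath>

    // gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    inline float FastGelu(float x)
    {
        const float k = 0.7978845608028654f; // sqrt(2 / pi)
        return 0.5f * x * (1.0f + std::tanh(k * (x + 0.044715f * x * x * x)));
    }

    // gemm_add_add_fastgelu epilogue: e = fastgelu(c + d0 + d1)
    inline void AddAddFastGelu(float& e, float c, float d0, float d1)
    {
        e = FastGelu(c + d0 + d1);
    }
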
* Added instances to CMake
* Modified the template parameters to make the instances work.
* Fixed typo in profiler
* Added v3 instances for gemm_add_relu
* addressed code review comments
* Added test for gemm_add_relu wmma instance
* Cleaned up the code.
* Added examples for gemm_add_relu
* Fixing typo to resolve build errors.
* Fixes applied for the precision loss.
* fix bilinear test after merge
* Removed the old wmma instances.
* Added wrapper and renamed the wmma_v3 instances
* Updated copyrights and added wrappers.
* Fixes applied according to review comments
* Apply 1 suggestion(s) to 1 file(s)
  Co-authored-by: Robin Voetter
* Removed the old wmma instances.
* Updated wrapper for the v3 instances
* removed the old wmma examples
* Renamed the v3 instances
* Deleted the gtest file added by mistake.
* Updated the profiler with wrapper
* Fixed test errors.
* Addressed the review comments
* Fixed the if-condition macros.
* REVERTED THE PROFILER CHANGES
* Revert "REVERTED THE PROFILER CHANGES"
  This reverts commit 8ba7f2a5cb92232a2160d25e89d194747d2c173e.
* Revert "Fixed test errors."
  This reverts commit a9a0071745937ece49a18a2ec5ae1463d26a9a2c.
* Revert "Updated the profiler with wrapper"
  This reverts commit 2ba5152e85a4b046562ca19ba437aab5ec0ad2ab.
* Added missing wrapper instances
* Updated copyrights.
* Fixed typo.
* Fixed copyrights.
* Updated copyrights.
* updated copyrights.
* comments on the atomics workaround
* fixed cmake comment
* Fix bug from merge
* clang-format-18
* Fix compilation error
* multi_abd wmma support:
  - Add multiple A and B support to the multiple D implementation (gridwise level)
  - Add multi_abd GEMM (device level)
  - Add instances (xdl parity)
  - Add tests (both xdl and wmma)
  - Add examples
  - Add ckProfiler support (both xdl and wmma)
* Fix bug in device print function
* Fix unused template parameter
* Add support for fwd conv in gridwise implementation.
  Identical to the run function for bwd data.
* Initial device implementation for grouped conv fwd multiABD wmma cshuffleV3.
  Functional, but needs some fixups and extra features in the future.
* Make relevant profilers print the number of valid instances to aid testing.
* Add instances for all vanilla 2D and 3D flavors for f16 and bf16, only one instance per instance list to save compile time for now.
  Also added an incomplete set of comp instances and bias_clamp for f16 2D, just to make sure the multiple-D aspects of the device implementation are working.
* Reset output buffer after each run in profile_grouped_conv_fwd_impl().
* Disable sharding for the new instances for now; it has a tendency to lead to linker errors on repeat builds.
* Add CTranspose optimization for NCHW cases, just like in the xdl cshuffle non-v3 device implementation.
* Add instances for all 8-bit 3D vanilla grouped conv fwd types, including mixed types but with the exception of the deprecated f16 comp fp8.
  Adapt the test so we can test 8-bit and mixed types.
* Add int8 instances for 2D vanilla grouped conv fwd, all layouts.
* Implement merged groups in device impl and add instances for merged groups 3D vanilla conv fwd
* Add merged groups instances for all 2D vanilla grouped conv fwd types and layouts.
* Implement multi-AB support for grouped conv fwd and add example.
* Add 1D instances
* Add D layout tests to IsSupportedArgument()
* Add comp and mem instances for all vanilla 2D grouped conv fwd types.
  Skipping "x2" and "part2" instance lists; they can be added later without special names if necessary.
* Add comp and mem instances for vanilla 3D grouped conv fwd.
  Skipped 2x and part2 instances; they can be added later in the same instance lists.
* Add some more tests for vanilla grouped conv fwd
* Add 2D bias clamp instances and tests
* Add 3D bias clamp instances and tests
* Add 2D and 3D clamp instances and tests
* Unify problem sizes across vanilla and clamp flavor tests
* Clean up device implementation: remove old todos, remove unnecessary comments and print statements, tweak the description, wrap all prints in an env check.
* Implement rotating memory and flush cache.
  Requires ad-hoc buffer size calculations.
* Remove wmma fp8 and bf8 instances when not targeting gfx12
* Add newer instances to DEVICE_INSTANCES so the main ckProfiler can build
* Remove old years for newly created files.
* No need to time kernels for now.
* Fixup comments
* Pass struct args to the Gridwise Run() function by reference.
* Don't use workspace memory in the case where A needs explicit transposition but B does not.
* Move calculation of rotating memory buffer sizes to Argument member functions.
* After the convolution to gemm transformation, the resulting 2D tensor descriptors are not necessarily RowMajor or ColumnMajor, so things should not rely on this distinction. (The conv-to-GEMM size mapping is sketched after this block of bullets.)
  Therefore, pass all RowMajor to the Gridwise and use a special version of CheckValidity that does not rely on 2D tensor layouts.
* Unify xdl and wmma example code for grouped conv fwd scaleadd ab
* Go back to passing RCR 2D tensor layouts to the gridwise gemm, and use CRC for the CTranspose case.
  Also remove the special convolution version of checkValidity(). It seems that no matter what 2D tensor layouts you pass to the gridwise gemm, no matter if you are using extraMN, and no matter if you are using the convolution version of checkValidity, the results of all tests are the same.
* Add wmma scaleadd ab instances to the device factory and add a completely new scaleadd_ab gtest for wmma cshufflev3 and xdl.
  Currently there is no profiler for scaleadd_ab, so I made my own inside the test. Furthermore, for XDL only the (NDHWGC, GKZYXC, NDHWGK) layout combination existed in the instance factory, so that is the only one I added for wmma cshufflev3 and the gtest as well. Another layout is tested in example 62, for xdl and wmma cshufflev3.
* Add support for V3 pipeline (tested).
  To be able to support num_loop < 3 we need the fixes from the batched gemm gemm MR, which was already merged upstream, so we just need to rebase or merge.
* Small post-merge fixup, everything seems to work.
* Do not build or run Xdl operations with the Wmma backend for now.
  Will be reverted before upstreaming.
* Extend scaleadd_ab instance lists
* Extend merged groups instance lists, including adaptations of xdl "2x" instances.
* Extend "comp" instance lists, including "2x" and "part2" instances.
  2x instances are disabled for now since they do not compile.
* Extend "mem" instance lists.
* Extend regular instance lists.
* Fixup comments and ignored kernel arg name
* Properly use the splitN offsets for D tensors in the gridwise Run() function.
  This was necessary to pass the bias_clamp_large_cases test.
* Make sure all strides in ComputePtrOffset are at least value-initialized to avoid undefined strides.
  Not convinced this struct is properly initialized in other code / future code.
* Re-enable sharding for wmma cshufflev3 instances
* Post-merge fix to vanilla test
* Optionally allow num_k_loop <= PrefetchStages in gridwise CheckValidity.
  Use this for grouped conv fwd, but not in general.
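
Background for the convolution-to-GEMM bullets above: grouped conv fwd is lowered to an implicit GEMM whose per-group problem sizes follow directly from the convolution shape, and once the resulting descriptors address more than 2GB the device op below splits the problem. A hedged sketch of the standard 2D forward-convolution mapping (names are illustrative):

    // Hedged sketch of the implicit-GEMM problem sizes for 2D grouped conv fwd
    // (per group): A is the im2col'd input, B the filter, E the output.
    struct GemmSizes
    {
        long long M; // N * Ho * Wo : one GEMM row per output pixel
        long long N; // K           : output channels per group
        long long K; // C * Y * X   : input channels x filter taps
    };

    inline GemmSizes MapConv2dFwdToGemm(long long N, long long K, long long C,
                                        long long Y, long long X,
                                        long long Ho, long long Wo)
    {
        return GemmSizes{N * Ho * Wo, K, C * Y * X};
    }
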
* Remove spurious ck_tile changes that were presumably introduced somewhere in the repeated merging from develop.
* Post-merge fixes.
  Make sure the new gridwise gemm wmma v3 common Run function can be used. Remove splitK and forceThreadTileTransfer for now. Also add the CShuffle epilogue argument.
* Disable FP8 / BF8 testing on CDNA1/2; it doesn't work anymore and needs to be either fixed or removed.
* Re-enable old wmma instances
* Re-enable Linqun's Xdl Wmma instances
* Small post-merge fixes
* Fix copyright headers
* Remove commented code snippet in gridwise
  Co-authored-by: Bartłomiej Kocot
* Limit the explicit cast added in threadwise_tensor_slice_transfer_v7r3 to only be used for f8, just in case it hurts performance.
* Adding tuned instance list for grouped conv fwd (#3288)
  The following flavors are updated with tuned instance lists:
  - grouped_conv2d_fwd
  - grouped_conv2d_fwd_bias_clamp
  - grouped_conv2d_fwd_clamp
  - grouped_conv3d_fwd
  - grouped_conv3d_fwd_bias_clamp
  - grouped_conv3d_fwd_clamp
  - grouped_conv3d_fwd_scaleadd_ab
  Refactored instance selection:
  - removed all the unnecessary instance tuples (comp/mem/16x16/generic)
  - removed all unnecessary layouts and data types
* Do not use std::remove_cvref_t; it does not exist in C++17, so use the custom one.
* Splitting grouped conv fwd instances (#3449)
* Disable unnecessary and failing tests related to the experimental CK builder
* Disable unnecessary ck builder experimental tests fully
* Added large tensor support for grouped conv fwd wmma

---------

Co-authored-by: Anca Hamuraru
Co-authored-by: apoorva
Co-authored-by: Anton Gorenko
Co-authored-by: Zoltan Lakatos
Co-authored-by: Cenxuan
Co-authored-by: Robin Voetter
Co-authored-by: Enrico Degregori
Co-authored-by: Kiefer van Teutem
Co-authored-by: Kiefer van Teutem <50830967+krithalith@users.noreply.github.com>
Co-authored-by: Bartłomiej Kocot
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>

[ROCm/composable_kernel commit: 7795e73b47a34a25b48a14f3e4e0e6d681fcbde5]
---
 ...ltiple_d_wmma_cshuffle_v3_large_tensor.hpp | 1456 +++++++++++++++++
 ...d_multiple_d_xdl_large_tensor_cshuffle.hpp |    2 +-
 ..._wmma_cshufflev3_large_tensor_instance.hpp |   83 +
 .../gpu/grouped_convolution_forward.hpp       |   17 +-
 ..._convolution_forward_wmma_large_tensor.inc |   78 +
 .../gpu/grouped_conv2d_fwd/CMakeLists.txt     |    2 +
 ...tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp |   38 +
 ..._tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp |   38 +
 .../gpu/grouped_conv3d_fwd/CMakeLists.txt     |    2 +
 ...sor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp |   38 +
 ...nsor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp |   38 +
 test/grouped_convnd_fwd/CMakeLists.txt        |    8 +-
 ...> test_grouped_convnd_fwd_large_cases.cpp} |    2 +-
 13 files changed, 1788 insertions(+), 14 deletions(-)
 create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp
 create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp
 create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma_large_tensor.inc
 create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/large_tensor/device_grouped_conv2d_fwd_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
 create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/large_tensor/device_grouped_conv2d_fwd_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp
 create mode
100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/large_tensor/device_grouped_conv3d_fwd_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/large_tensor/device_grouped_conv3d_fwd_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp rename test/grouped_convnd_fwd/{test_grouped_convnd_fwd_large_cases_xdl.cpp => test_grouped_convnd_fwd_large_cases.cpp} (100%) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp new file mode 100644 index 0000000000..08d0f296f0 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp @@ -0,0 +1,1456 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_conv_fwd_grouped_gemm_wmma_cshuffle_v3( + Array gemm_desc_kernel_args, + const index_t gemms_count, + const ComputePtrOffset compute_ptr_offset_of_groups, + const ComputePtrOffset compute_ptr_offset_of_n) +{ +#if defined(__gfx11__) || defined(__gfx12__) + using Epilogue = typename GridwiseGemm::EpilogueCShuffle; + __shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte()]; + + const index_t block_id_x = __builtin_amdgcn_readfirstlane(blockIdx.x); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + + const long_index_t a_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); + const long_index_t e_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t b_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetBPtrOffset(n_idx)); + const auto& ds_n_offset = 
compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + index_t left = 0; + index_t right = gemms_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id_x >= gemm_desc_kernel_args[group_id].BlockStart_ && + block_id_x < gemm_desc_kernel_args[group_id].BlockEnd_)) && + left <= right) + { + if(block_id_x < gemm_desc_kernel_args[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + const auto& gemm_arg = gemm_desc_kernel_args[group_id]; + const index_t block_x = block_id_x - gemm_arg.BlockStart_; + + typename GridwiseGemm::AsGridPointer p_as_grid_; + static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) { + using ADataType_ = remove_cvref_t>; + p_as_grid_(i) = + static_cast(gemm_arg.a_ptrs_[i]) + a_group_offset + a_n_offset; + }); + + typename GridwiseGemm::BsGridPointer p_bs_grid_; + static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) { + using BDataType_ = remove_cvref_t>; + p_bs_grid_(i) = + static_cast(gemm_arg.b_ptrs_[i]) + b_group_offset + b_n_offset; + }); + + typename GridwiseGemm::DsGridPointer p_ds_grid_; + static_for<0, GemmArgs::NumDTensor, 1>{}([&](auto i) { + using DDataType_ = + remove_cvref_t>; + p_ds_grid_(i) = static_cast(gemm_arg.ds_ptrs_[i]) + ds_group_offset[i] + + ds_n_offset[i]; + }); + + const auto as_grid_desc_ak0_m_ak1 = generate_tuple([&](auto) { return gemm_arg.a_grid_desc_; }, + Number{}); + + const auto bs_grid_desc_bk0_n_bk1 = generate_tuple([&](auto) { return gemm_arg.b_grid_desc_; }, + Number{}); + + const auto& ds_grid_desc = gemm_arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_; + const auto& e_grid_desc = gemm_arg.e_grid_desc_mblock_mperblock_nblock_nperblock_; + + const auto block_2_ctile_map = + typename GridwiseGemm::Block2CTileMap{gemm_arg.M_, gemm_arg.N_, 4}; + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_x)); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc.GetLength(Number<0>{}), e_grid_desc.GetLength(Number<2>{})))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[Number<0>{}]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[Number<1>{}]); + + using AScale = typename GridwiseGemm::BlockwiseGemmPipe::Empty; + auto a_scale_struct = AScale{}; + + using BScale = typename GridwiseGemm::BlockwiseGemmPipe::Empty; + auto b_scale_struct = BScale{}; + + const index_t num_k_block_per_scale = GridwiseGemm::GetKBlockPerScale(); + + auto epilogue_args = Epilogue{}; + + GridwiseGemm::Base::template Run(p_as_grid_, + p_bs_grid_, + p_ds_grid_, + gemm_arg.e_ptr_ + e_group_offset + e_n_offset, + p_shared, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc, + e_grid_desc, + gemm_arg.a_element_op_, + gemm_arg.b_element_op_, + gemm_arg.cde_element_op_, + block_m_id, + block_n_id, + num_k_block_per_scale, + a_scale_struct, + b_scale_struct, + epilogue_args); +#else + ignore = gemm_desc_kernel_args; + ignore = gemms_count; + ignore = compute_ptr_offset_of_groups; + ignore = compute_ptr_offset_of_n; +#endif +} + +} // namespace + +template +using is_tuple = decltype(std::declval().IsTuple()); + +template +struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor + : public DeviceGroupedConvFwdMultipleABD +{ + using DeviceOp = 
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor; + + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr index_t MaxGemmsNum = 32; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I6 = Number<6>{}; + // K1 = Max Vector Access Pixels + static constexpr auto K1Number = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = 16; + + using ConvToGemmFwdTransformerIndexT = TransformConvFwdToGemm; + + using ConvToGemmFwdTransformerLongIndexT = TransformConvFwdToGemm; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + template + static auto MakeAGridDescriptor(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer) + { + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + const auto M = in_gemmm_gemmk_desc.GetLength(I0); + const auto K = in_gemmm_gemmk_desc.GetLength(I1); + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static auto MakeBGridDescriptor(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer) + { + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + const auto N = wei_gemmn_gemmk_desc.GetLength(I0); + const auto K = wei_gemmn_gemmk_desc.GetLength(I1); + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + wei_gemmn_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static auto + MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer) + { + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); + + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + static auto + MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer); + }, + Number{}); + } + + static auto CastDsPointers(const std::array& p_ds) + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + return static_cast(p_ds[i]); + }, + Number{}); + } + + using DsPointer = decltype(CastDsPointers(std::array{})); + + using GemmAsDataType = Tuple; + using GemmBsDataType = Tuple; + using GemmDsDataType = DsDataType; + + using CDEBlockTransferScalarPerVectors = + typename uniform_sequence_gen::type; + // desc for problem definition + constexpr static ConvToGemmFwdTransformerIndexT 
        dummy_conv_to_gemm_transformer;
+    using AGridDesc = decltype(MakeAGridDescriptor(dummy_conv_to_gemm_transformer));
+    using BGridDesc = decltype(MakeBGridDescriptor(dummy_conv_to_gemm_transformer));
+    using DsGridDesc_M_N =
+        remove_cvref_t;
+    using EGridDesc_M_N =
+        remove_cvref_t(dummy_conv_to_gemm_transformer))>;
+
+    static auto
+    GenerateConvToGemmTransforms(ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformer_base,
+                                 const ADataType* a_grid_ptr_base,
+                                 DsPointer ds_grid_ptr_base,
+                                 EDataType* c_grid_ptr_base)
+    {
+        // Max number of splits
+        // We need it to avoid an infinite loop
+        constexpr index_t max_split_numbers = MaxGemmsNum / 2;
+        // Arrays to store transformers with descs smaller than 2GB
+        Array conv_to_gemm_transformers_arr;
+        Array a_grid_ptrs_arr;
+        Array ds_grid_ptrs_arr;
+        Array c_grid_ptrs_arr;
+        // Queue for splitting
+        std::queue conv_to_gemm_transformers_queue(
+            {conv_to_gemm_transformer_base});
+        std::queue a_grid_ptrs_queue({a_grid_ptr_base});
+        std::queue ds_grid_ptrs_queue({ds_grid_ptr_base});
+        std::queue c_grid_ptrs_queue({c_grid_ptr_base});
+
+        index_t gemms_number = 0;
+        index_t split_numbers = 0;
+        // Algorithm:
+        // While the queue is not empty:
+        // 1. Get a transformer from the queue.
+        // 2. If its descs are smaller than 2GB, push it to the result array.
+        // 3. If its descs are bigger than 2GB, split it into left and right transformers.
+        while(!conv_to_gemm_transformers_queue.empty() && split_numbers < max_split_numbers &&
+              gemms_number < MaxGemmsNum)
+        {
+            // Get transformer from the queue
+            const auto& conv_to_gemm_transformer = conv_to_gemm_transformers_queue.front();
+            const ADataType* a_grid_ptr = a_grid_ptrs_queue.front();
+            DsPointer ds_grid_ptr = ds_grid_ptrs_queue.front();
+            EDataType* c_grid_ptr = c_grid_ptrs_queue.front();
+
+            // Check that the convolution does not exceed 2GB
+            if(conv_to_gemm_transformer.AreDescriptorsSmallerThan2GB())
+            {
+                // If yes, push into result array
+                conv_to_gemm_transformers_arr(gemms_number) =
+                    ConvToGemmFwdTransformerIndexT{conv_to_gemm_transformer};
+                a_grid_ptrs_arr(gemms_number) = a_grid_ptr;
+                ds_grid_ptrs_arr(gemms_number) = ds_grid_ptr;
+                c_grid_ptrs_arr(gemms_number) = c_grid_ptr;
+                gemms_number++;
+            }
+            else
+            {
+                // If no, split into left and right convolutions
+                ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformers_left_part,
+                    conv_to_gemm_transformers_right_part;
+                const ADataType* a_grid_right_ptr;
+                DsPointer ds_grid_right_ptr;
+                EDataType* c_grid_right_ptr;
+
+                ck::tie(conv_to_gemm_transformers_left_part,
+                        conv_to_gemm_transformers_right_part,
+                        a_grid_right_ptr,
+                        ds_grid_right_ptr,
+                        c_grid_right_ptr) =
+                    conv_to_gemm_transformer.SplitConvProblem(a_grid_ptr, ds_grid_ptr, c_grid_ptr);
+
+                conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_left_part);
+                conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_right_part);
+                // Left offsets remain the same
+                a_grid_ptrs_queue.push(a_grid_ptr);
+                a_grid_ptrs_queue.push(a_grid_right_ptr);
+                ds_grid_ptrs_queue.push(ds_grid_ptr);
+                ds_grid_ptrs_queue.push(ds_grid_right_ptr);
+                c_grid_ptrs_queue.push(c_grid_ptr);
+                c_grid_ptrs_queue.push(c_grid_right_ptr);
+                split_numbers++;
+            }
+            // Remove from the queue
+            conv_to_gemm_transformers_queue.pop();
+            a_grid_ptrs_queue.pop();
+            ds_grid_ptrs_queue.pop();
+            c_grid_ptrs_queue.pop();
+        }
+
+        const bool is_split_valid = conv_to_gemm_transformers_queue.empty();
+
+        return ck::make_tuple(conv_to_gemm_transformers_arr,
+                              a_grid_ptrs_arr,
+                              ds_grid_ptrs_arr,
+                              c_grid_ptrs_arr,
+                              gemms_number,
+                              is_split_valid);
+    }
+
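// The function above is the core of the large-tensor support: it repeatedly
// bisects any convolution whose GEMM descriptors would exceed 2GB until every
// piece fits, giving up after MaxGemmsNum / 2 splits or MaxGemmsNum pieces.
// A minimal host-side analogue of the loop, with a hypothetical Conv type
// (hedged sketch only, not part of this patch's API):
//
//   std::queue<Conv> work({whole_problem});
//   std::vector<Conv> result;
//   int splits = 0;
//   while(!work.empty() && splits < MaxGemmsNum / 2 && result.size() < MaxGemmsNum)
//   {
//       Conv c = work.front();
//       work.pop();
//       if(c.DescriptorsSmallerThan2GB())
//           result.push_back(c); // small enough: emit as one GEMM
//       else
//       {
//           auto [l, r] = c.SplitInHalf(); // e.g. split the batch/spatial dim
//           work.push(l);
//           work.push(r); // retry both halves
//           ++splits;
//       }
//   }
//   bool is_split_valid = work.empty(); // every piece fits under 2GB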
+ using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + DsLayout, + tensor_layout::gemm::RowMajor, + GemmAsDataType, + GemmBsDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + K1, + K1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + AComputeDataType, + BComputeDataType, + false, + false, + false, + true>; + + // desc for blockwise copy + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}, 1, 1))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + EGridDesc_M_N{}, 1, 1))>; + + // Structure for each gemm(conv) + struct GemmArgs + { + using AsDataType = GemmAsDataType; + using BsDataType = GemmBsDataType; + using DsDataTypeTuple = GemmDsDataType; + + static constexpr index_t NumATensor = GridwiseGemm::NumATensor; + static constexpr index_t NumBTensor = GridwiseGemm::NumBTensor; + static constexpr index_t NumDTensor = DeviceOp::NumDTensor; + + std::array a_ptrs_{}; + std::array b_ptrs_{}; + std::array ds_ptrs_{}; + EDataType* e_ptr_ = nullptr; + + AElementwiseOperation a_element_op_{}; + BElementwiseOperation b_element_op_{}; + CDEElementwiseOperation cde_element_op_{}; + + index_t M_ = 0; + index_t N_ = 0; + + AGridDesc a_grid_desc_{}; + BGridDesc b_grid_desc_{}; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_{}; + + ck::index_t BlockStart_ = 0; + ck::index_t BlockEnd_ = 0; + }; + + // Argument + struct Argument : public BaseArgument + { + template + void init_gemm_args(const std::array& p_as_grid, + const std::array& p_bs_grid, + const std::array& p_ds_grid, + EDataType* p_e_grid, + const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const DsGridDesc_M_N_& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + index_t gemm_m, + index_t gemm_n, + index_t gemm_k, + index_t BlockStart, + index_t BlockEnd) + { + std::array stride_as{}; + std::array stride_bs{}; + std::array stride_ds{}; + + auto gemm_arg = typename GridwiseGemm::Argument{p_as_grid, + p_bs_grid, + p_ds_grid, + p_e_grid, + gemm_m, + gemm_n, + gemm_k, + stride_as, + stride_bs, + stride_ds, + index_t{0}, + index_t{1}, + a_element_op_, + b_element_op_, + cde_element_op_}; + + if(GridwiseGemm::CheckValidity(gemm_arg, true)) + { + const auto m_block = GridwiseGemm::CalculateMBlock(gemm_m); + const auto n_block = 
GridwiseGemm::CalculateNBlock(gemm_n);
+
+            GemmArgs new_args{};
+            new_args.a_ptrs_ = p_as_grid;
+            new_args.b_ptrs_ = p_bs_grid;
+            new_args.ds_ptrs_ = p_ds_grid;
+            new_args.e_ptr_ = p_e_grid;
+
+            new_args.a_element_op_ = a_element_op_;
+            new_args.b_element_op_ = b_element_op_;
+            new_args.cde_element_op_ = cde_element_op_;
+
+            new_args.M_ = gemm_m;
+            new_args.N_ = gemm_n;
+
+            new_args.a_grid_desc_ = a_grid_desc;
+            new_args.b_grid_desc_ = b_grid_desc;
+            new_args.ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                    ds_grid_desc_m_n, m_block, n_block);
+            new_args.e_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                    e_grid_desc_m_n, m_block, n_block);
+
+            new_args.BlockStart_ = BlockStart;
+            new_args.BlockEnd_ = BlockEnd;
+
+            gemm_desc_kernel_args_.At(valid_gemms_count_) = new_args;
+
+            valid_gemms_count_++;
+        }
+    }
+    Argument(const void* p_a,
+             const void* p_b,
+             const std::array& p_ds,
+             void* p_e,
+             const std::array& a_g_n_c_wis_lengths,
+             const std::array& a_g_n_c_wis_strides,
+             const std::array& b_g_k_c_xs_lengths,
+             const std::array& b_g_k_c_xs_strides,
+             const std::array, NumDTensor>&
+                 ds_g_n_k_wos_lengths,
+             const std::array, NumDTensor>&
+                 ds_g_n_k_wos_strides,
+             const std::array& e_g_n_k_wos_lengths,
+             const std::array& e_g_n_k_wos_strides,
+             const std::array& conv_filter_strides,
+             const std::array& conv_filter_dilations,
+             const std::array& input_left_pads,
+             const std::array& input_right_pads,
+             const AElementwiseOperation& a_element_op,
+             const BElementwiseOperation& b_element_op,
+             const CDEElementwiseOperation& cde_element_op)
+        : num_group_{static_cast(a_g_n_c_wis_lengths[0])},
+          compute_ptr_offset_of_groups_{},
+          compute_ptr_offset_of_n_{},
+          a_element_op_{a_element_op},
+          b_element_op_{b_element_op},
+          cde_element_op_{cde_element_op},
+          a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
+          a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
+          b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
+          b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
+          ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
+          ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
+          e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
+          e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
+          conv_filter_strides_{conv_filter_strides},
+          conv_filter_dilations_{conv_filter_dilations},
+          input_left_pads_{input_left_pads},
+          input_right_pads_{input_right_pads}
+    {
+        // Perform grouped gemm: generate an array of transformers for the convolution
+        Array conv_to_gemm_transformer_arr;
+        Array a_grid_ptrs;
+        Array ds_grid_ptrs;
+        Array c_grid_ptrs;
+
+        DsPointer p_ds_casted = CastDsPointers(p_ds);
+
+        ck::tie(conv_to_gemm_transformer_arr,
+                a_grid_ptrs,
+                ds_grid_ptrs,
+                c_grid_ptrs,
+                gemms_count_,
+                is_split_valid_) =
+            GenerateConvToGemmTransforms(
+                ConvToGemmFwdTransformerLongIndexT{a_g_n_c_wis_lengths_,
+                                                   a_g_n_c_wis_strides_,
+                                                   b_g_k_c_xs_lengths_,
+                                                   b_g_k_c_xs_strides_,
+                                                   e_g_n_k_wos_lengths_,
+                                                   e_g_n_k_wos_strides_,
+                                                   conv_filter_strides_,
+                                                   conv_filter_dilations_,
+                                                   input_left_pads_,
+                                                   input_right_pads_},
+                static_cast(p_a),
+                p_ds_casted,
+                static_cast(p_e));
+
+        grid_size_ = 0;
+        valid_gemms_count_ = 0;
+
+        if(is_split_valid_)
+        {
+            // Create GemmArg for each gemm(conv)
+            for(index_t i = 0; i < gemms_count_; i++)
+            {
+                const AGridDesc a_grid_desc{
+                    DeviceOp::MakeAGridDescriptor(conv_to_gemm_transformer_arr[i])};
+                const BGridDesc b_grid_desc{
+                    DeviceOp::MakeBGridDescriptor(conv_to_gemm_transformer_arr[i])};
+                const EGridDesc_M_N
e_grid_desc_m_n{DeviceOp::MakeEGridDescriptor_M_N( + conv_to_gemm_transformer_arr[i])}; + + const auto ds_grid_desc_m_n = + DeviceOp::MakeDsGridDescriptor_M_N(conv_to_gemm_transformer_arr[i]); + + const index_t GemmM = e_grid_desc_m_n.GetLength(I0); + const index_t GemmN = e_grid_desc_m_n.GetLength(I1); + const index_t GemmK = [&]() { + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2); + }(); + + std::array p_as_grid{}; + p_as_grid[0] = static_cast(a_grid_ptrs[i]); + + std::array p_bs_grid{}; + p_bs_grid[0] = static_cast(static_cast(p_b)); + + std::array p_ds_grid{}; + if constexpr(NumDTensor > 0) + { + static_for<0, NumDTensor, 1>{}([&](auto d) { + p_ds_grid[d.value] = static_cast(ds_grid_ptrs[i].At(d)); + }); + } + + const index_t grid_size_grp = + GridwiseGemm::Block2CTileMap::CalculateGridSize(GemmM, GemmN); + + const index_t BlockStart = grid_size_; + const index_t BlockEnd = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + init_gemm_args(p_as_grid, + p_bs_grid, + p_ds_grid, + c_grid_ptrs[i], + a_grid_desc, + b_grid_desc, + ds_grid_desc_m_n, + e_grid_desc_m_n, + GemmM, + GemmN, + GemmK, + BlockStart, + BlockEnd); + } + + // N is the same for all convs + conv_N_per_block_ = static_cast(conv_to_gemm_transformer_arr[I0].N_); + } + + // Strides for G and N remain the same + compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides_[0]; + compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides_[0]; + compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides_[0]; + + compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides_[1] * conv_N_per_block_; + compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0]; + compute_ptr_offset_of_n_.BatchStrideDs_(i) = + ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_; + }); + } + + void Print() const + { + std::cout << "===== Convolution summary =====" << std::endl; + std::cout << " num_group=" << num_group_ + << ", conv_N_total=" << a_g_n_c_wis_lengths_[I1] + << ", conv_N_per_block=" << conv_N_per_block_ << std::endl; + std::cout << " gemms_count=" << gemms_count_ + << ", valid_gemms_count=" << valid_gemms_count_ + << ", is_split_valid=" << std::boolalpha << is_split_valid_ + << std::noboolalpha << ", grid_size=" << grid_size_ << std::endl; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + std::cout << " Ds[" << i.value + << "] group stride=" << compute_ptr_offset_of_groups_.BatchStrideDs_(i) + << ", n stride=" << compute_ptr_offset_of_n_.BatchStrideDs_(i) + << std::endl; + }); + + std::cout << "===== GEMM splits =====" << std::endl; + for(index_t i = 0; i < valid_gemms_count_; ++i) + { + const auto& gemm = gemm_desc_kernel_args_[i]; + + const auto M = gemm.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I0) * + gemm.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I1); + const auto N = gemm.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I2) * + gemm.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3); + + const auto K = [&]() { + return gemm.a_grid_desc_.GetLength(I0) * gemm.a_grid_desc_.GetLength(I2); + }(); + + std::cout << " gemm[" << i << "] block_range=[" << gemm.BlockStart_ << ", " + << gemm.BlockEnd_ << ") (M,N,K)=(" << M << ", " << N << ", " << K + << ") grid_span=" << (gemm.BlockEnd_ - gemm.BlockStart_) << std::endl; + std::cout << " A descriptor: " << gemm.a_grid_desc_ << std::endl; + std::cout << 
" B descriptor: " << gemm.b_grid_desc_ << std::endl; + std::cout << " E[MBlock, MPerBlock, NBlock, NPerBlock]: " + << gemm.e_grid_desc_mblock_mperblock_nblock_nperblock_ << std::endl; + + static_for<0, NumDTensor, 1>{}([&](auto d_idx) { + std::cout << " D" << d_idx.value << " descriptor: " + << gemm.ds_grid_desc_mblock_mperblock_nblock_nperblock_(d_idx) + << std::endl; + }); + } + } + + index_t num_group_; + index_t conv_N_per_block_; + + Array gemm_desc_kernel_args_; + + index_t grid_size_; + index_t gemms_count_; + index_t valid_gemms_count_; + + bool is_split_valid_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_groups_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + std::array a_g_n_c_wis_lengths_; + std::array a_g_n_c_wis_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_k_wos_lengths_; + std::array, NumDTensor> ds_g_n_k_wos_strides_; + std::array e_g_n_k_wos_lengths_; + std::array e_g_n_k_wos_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + const index_t num_workgroups_per_Conv_N = + arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; + + const index_t gdx = arg.grid_size_; + const index_t gdy = arg.num_group_; + const index_t gdz = num_workgroups_per_Conv_N; + + const auto K = [&]() { + return arg.gemm_desc_kernel_args_[I0].a_grid_desc_.GetLength(I0) * + arg.gemm_desc_kernel_args_[I0].a_grid_desc_.GetLength(I2); + }(); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + const auto kernel = kernel_grouped_conv_fwd_grouped_gemm_wmma_cshuffle_v3< + Gridwise, + MaxGemmsNum, + GemmArgs, + ComputePtrOffsetOfStridedBatch, + has_main_loop, + InMemoryDataOperationEnum::Set>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.gemm_desc_kernel_args_, + arg.gemms_count_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); + }; + + if(Gridwise::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + return RunImp(arg, stream_config); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + const long_index_t K = arg.b_g_k_c_xs_lengths_[I1]; + const long_index_t C = arg.b_g_k_c_xs_lengths_[I2]; + + bool ds_valid = true; + static_for<0, NumDTensor, 1>{}([&](auto i) { + for(int d = 0; d < NDimSpatial + I3; d++) + { + if(arg.ds_g_n_k_wos_strides_[i][d] != arg.e_g_n_k_wos_strides_[d]) + { + ds_valid = false; + } + if(arg.ds_g_n_k_wos_lengths_[i][d] != arg.e_g_n_k_wos_lengths_[d]) + { + ds_valid 
= false; + } + } + + using DDataType = remove_cvref_t>; + static_assert(is_same_v); + }); + + if(!ds_valid) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Ds tensors must have the same dimensions as E tensor!" << " In " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + + // Check if all descs are valid + if(!(arg.is_split_valid_ && arg.gemms_count_ == arg.valid_gemms_count_ && + arg.valid_gemms_count_ > 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "GEMM splits are not valid!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + + // check device + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Incorrect accumulator data type!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "WMMA large tensor not supported on this device!" << " In " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t ConvStride = arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "The input parameters are not valid for " + "Filter1x1Stride1Pad0!" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && LeftPad == 0 && RightPad == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "The input parameters are not valid for " + "Filter1x1Pad0!" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + // Check access per C + if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Parameters for A Layout incorrect!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported A Layout!" 
<< " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Parameters for B Layout incorrect!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported B Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + + // check vector access of Ds + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v) + { + const index_t Kds = arg.ds_g_n_k_wos_lengths_[i][2]; + + if(!(Kds % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Parameters for D tensor Layout incorrect!" << " In " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + valid = false; + } + } + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported D Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + valid = false; + } + }); + + if(!valid) + { + return false; + } + + // check vector access of E + if constexpr(!(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported E Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(!(K % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Parameters for E Layout incorrect!" 
<< " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + std::array a_lengths_i64; + std::array a_strides_i64; + std::array b_lengths_i64; + std::array b_strides_i64; + std::array, NumDTensor> ds_lengths_i64; + std::array, NumDTensor> ds_strides_i64; + std::array e_lengths_i64; + std::array e_strides_i64; + std::array conv_strides_i64; + std::array conv_dilations_i64; + std::array left_pads_i64; + std::array right_pads_i64; + + array_convert(a_lengths_i64, a_g_n_c_wis_lengths); + array_convert(a_strides_i64, a_g_n_c_wis_strides); + array_convert(b_lengths_i64, b_g_k_c_xs_lengths); + array_convert(b_strides_i64, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; ++d) + { + array_convert(ds_lengths_i64[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_strides_i64[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_lengths_i64, e_g_n_k_wos_lengths); + array_convert(e_strides_i64, e_g_n_k_wos_strides); + array_convert(conv_strides_i64, conv_filter_strides); + array_convert(conv_dilations_i64, conv_filter_dilations); + array_convert(left_pads_i64, input_left_pads); + array_convert(right_pads_i64, input_right_pads); + + return Argument{p_a, + p_b, + p_ds, + p_e, + a_lengths_i64, + a_strides_i64, + b_lengths_i64, + b_strides_i64, + ds_lengths_i64, + ds_strides_i64, + e_lengths_i64, + e_strides_i64, + conv_strides_i64, + conv_dilations_i64, + left_pads_i64, + right_pads_i64, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto + MakeArgument(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + 
} + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + std::array a_g_n_c_wis_lengths_long{}; + std::array a_g_n_c_wis_strides_long{}; + std::array b_g_k_c_xs_lengths_long{}; + std::array b_g_k_c_xs_strides_long{}; + std::array, NumDTensor> + ds_g_n_k_wos_lengths_long{}; + std::array, NumDTensor> + ds_g_n_k_wos_strides_long{}; + std::array e_g_n_k_wos_lengths_long{}; + std::array e_g_n_k_wos_strides_long{}; + std::array conv_filter_strides_long{}; + std::array conv_filter_dilations_long{}; + std::array input_left_pads_long{}; + std::array input_right_pads_long{}; + + array_convert(a_g_n_c_wis_lengths_long, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_long, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_long, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_long, b_g_k_c_xs_strides); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + array_convert(ds_g_n_k_wos_lengths_long[i], ds_g_n_k_wos_lengths[i]); + array_convert(ds_g_n_k_wos_strides_long[i], ds_g_n_k_wos_strides[i]); + }); + + array_convert(e_g_n_k_wos_lengths_long, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_long, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_long, conv_filter_strides); + array_convert(conv_filter_dilations_long, conv_filter_dilations); + array_convert(input_left_pads_long, input_left_pads); + array_convert(input_right_pads_long, input_right_pads); + + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths_long, + a_g_n_c_wis_strides_long, + b_g_k_c_xs_lengths_long, + b_g_k_c_xs_strides_long, + ds_g_n_k_wos_lengths_long, + ds_g_n_k_wos_strides_long, + e_g_n_k_wos_lengths_long, + e_g_n_k_wos_strides_long, + conv_filter_strides_long, + conv_filter_dilations_long, + input_left_pads_long, + input_right_pads_long, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + 
+
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        const std::array<const void*, NumDTensor>& p_ds,
+                        void* p_e,
+                        const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
+                        const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
+                        const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
+                        const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
+                        const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
+                            ds_g_n_k_wos_lengths,
+                        const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
+                            ds_g_n_k_wos_strides,
+                        const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
+                        const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
+                        const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
+                        const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
+                        const std::array<long_index_t, NDimSpatial>& input_left_pads,
+                        const std::array<long_index_t, NDimSpatial>& input_right_pads,
+                        const AElementwiseOperation& a_element_op,
+                        const BElementwiseOperation& b_element_op,
+                        const CDEElementwiseOperation& cde_element_op) override
+    {
+        return std::make_unique<Argument>(p_a,
+                                          p_b,
+                                          p_ds,
+                                          p_e,
+                                          a_g_n_c_wis_lengths,
+                                          a_g_n_c_wis_strides,
+                                          b_g_k_c_xs_lengths,
+                                          b_g_k_c_xs_strides,
+                                          ds_g_n_k_wos_lengths,
+                                          ds_g_n_k_wos_strides,
+                                          e_g_n_k_wos_lengths,
+                                          e_g_n_k_wos_strides,
+                                          conv_filter_strides,
+                                          conv_filter_dilations,
+                                          input_left_pads,
+                                          input_right_pads,
+                                          a_element_op,
+                                          b_element_op,
+                                          cde_element_op);
+    }
+
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return std::make_unique<Invoker>(Invoker{});
+    }
+
+    std::string GetTypeString() const override
+    {
+        std::stringstream ss;
+        ss << "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor" << "<" << BlockSize
+           << ", " << MPerBlock << ", " << NPerBlock << ", "
+           << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " << MPerWmma
+           << ", " << NPerWmma << ", " << MRepeat << ", " << NRepeat << ", "
+           << ABlockTransferSrcScalarPerVector << ", " << BBlockTransferSrcScalarPerVector << ", "
+           << CDEShuffleBlockTransferScalarPerVector_NPerBlock << ", " << CShuffleMRepeatPerShuffle
+           << ", " << CShuffleNRepeatPerShuffle << ">";
+        return ss.str();
+    }
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
index b21af2abb0..7c121f1482 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
         Array a_grid_ptrs_arr;
         Array ds_grid_ptrs_arr;
         Array c_grid_ptrs_arr;
-        // Queue for spliting
+        // Queue for splitting
        std::queue conv_to_gemm_transformers_queue(
            {conv_to_gemm_transformer_base});
        std::queue a_grid_ptrs_queue({a_grid_ptr_base});
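[Editor's note] The hunk above only fixes a comment, but the queue it names is the heart of the large-tensor path: a conv-to-GEMM transformer is enqueued, and any piece whose GEMM extents would exceed 32-bit indexing is split and re-enqueued until every piece fits. A simplified sketch of that idea, with hypothetical names and a caller-supplied size limit:

    #include <cstdint>
    #include <queue>
    #include <vector>

    // Hypothetical work item: one piece of the convolution, described by the
    // GEMM extents it maps to.
    struct ConvWorkItem
    {
        std::int64_t gemm_m;
        std::int64_t gemm_k;
    };

    // Split oversized pieces in half along M and re-enqueue until each piece's
    // element count fits under `limit` (e.g. INT32_MAX for 32-bit indexing).
    std::vector<ConvWorkItem> split_until_fits(ConvWorkItem base, std::int64_t limit)
    {
        std::vector<ConvWorkItem> out;
        std::queue<ConvWorkItem> work({base});
        while(!work.empty())
        {
            ConvWorkItem w = work.front();
            work.pop();
            if(w.gemm_m * w.gemm_k <= limit || w.gemm_m <= 1) // guard: cannot split further
            {
                out.push_back(w);
            }
            else
            {
                ConvWorkItem lo = w;
                ConvWorkItem hi = w;
                lo.gemm_m = w.gemm_m / 2; // halve the M extent
                hi.gemm_m = w.gemm_m - lo.gemm_m;
                work.push(lo);
                work.push(hi);
            }
        }
        return out;
    }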
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp
new file mode 100644
index 0000000000..c3769fbfd0
--- /dev/null
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp
@@ -0,0 +1,83 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp"
+#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using BF16 = ck::bhalf_t;
+using F16  = ck::half_t;
+using F32  = float;
+using I8   = int8_t;
+using I32  = int32_t;
+
+using Empty_Tuple = ck::Tuple<>;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using namespace ck::tensor_layout::convolution;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto ConvFwdDefault =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
+
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+template <ck::index_t NDimSpatial,
+          typename ALayout,
+          typename BLayout,
+          typename DsLayout,
+          typename ELayout,
+          ConvolutionForwardSpecialization ConvSpec>
+using device_grouped_conv_fwd_wmma_large_tensor_f16_instances = std::tuple<
+    // clang-format off
+    //########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+    //########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MBlock_MWaveMPerXdl| ScalarPerVector|
+    //########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>,
+    DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>,
+    DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
+    DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
+    DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1>
+    // clang-format on
+    >;
+
+template <ck::index_t NDimSpatial,
+          typename ALayout,
+          typename BLayout,
+          typename DsLayout,
+          typename ELayout,
+          ConvolutionForwardSpecialization ConvSpec>
+using device_grouped_conv_fwd_wmma_large_tensor_bf16_instances = std::tuple<
+    // clang-format off
+    //########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+    //########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MBlock_MWaveMPerXdl| ScalarPerVector|
+    //########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>,
+    DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>,
+    DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
+    DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
+    DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1>
+    // clang-format on
+    >;
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
index e869a08ab7..d38aa66ece 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
@@ -31,6 +31,7 @@
 #include "grouped_convolution_forward_wmma_cshufflev3.inc"
 #include "grouped_convolution_forward_mem_inter_wmma_cshufflev3.inc"
 #include "grouped_convolution_forward_mem_intra_wmma_cshufflev3.inc"
+#include "grouped_convolution_forward_wmma_large_tensor.inc"
 #endif
 
 namespace ck {
@@ -794,8 +795,8 @@ struct DeviceOperationInstanceFactory
>>& instances);
+#endif
+
+#ifdef CK_ENABLE_BF16
+void add_device_grouped_conv2d_fwd_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, NHWGC, GKYXC, Empty_Tuple, NHWGK,
+        BF16, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
+#endif
+
+#ifdef CK_ENABLE_BF16
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK,
+        BF16, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
+#endif
+
+#ifdef CK_ENABLE_FP16
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK,
+        F16, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
+#endif
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
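[Editor's note] These declarations are picked up by DeviceOperationInstanceFactory, so applications can enumerate the new large-tensor instances through the usual interface type. A hedged usage sketch (the interface template arguments follow the standard CK grouped-conv convention; exact headers and values are assumptions):

    // Assumes the CK instance library is linked and this header is included:
    // #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp"
    #include <iostream>

    void list_large_tensor_fp16_instances()
    {
        using PassThrough = ck::tensor_operation::element_wise::PassThrough;
        using DeviceOp    = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
            2,                                     // NDimSpatial
            ck::tensor_layout::convolution::NHWGC, // input layout
            ck::tensor_layout::convolution::GKYXC, // weight layout
            ck::Tuple<>,                           // no D tensors
            ck::tensor_layout::convolution::NHWGK, // output layout
            ck::half_t, ck::half_t, ck::Tuple<>, ck::half_t,
            PassThrough, PassThrough, PassThrough>;

        const auto op_ptrs = ck::tensor_operation::device::instance::
            DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

        for(const auto& op_ptr : op_ptrs)
        {
            std::cout << op_ptr->GetTypeString() << '\n'; // tile config string, as composed above
        }
    }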
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt
index 30063d268e..380c83fa92 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt
@@ -123,6 +123,8 @@ set(GROUPED_CONV2D_FWD
     wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part2.cpp
     wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part3.cpp
     wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part4.cpp
+    wmma/large_tensor/device_grouped_conv2d_fwd_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
+    wmma/large_tensor/device_grouped_conv2d_fwd_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp
 )
 # Add generated files for sharded instantiations.
 include(ShardInstantiation)
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/large_tensor/device_grouped_conv2d_fwd_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/large_tensor/device_grouped_conv2d_fwd_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
new file mode 100644
index 0000000000..6540c214cc
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/large_tensor/device_grouped_conv2d_fwd_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
@@ -0,0 +1,38 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+void add_device_grouped_conv2d_fwd_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+                                                              NHWGC,
+                                                              GKYXC,
+                                                              Empty_Tuple,
+                                                              NHWGK,
+                                                              BF16,
+                                                              BF16,
+                                                              Empty_Tuple,
+                                                              BF16,
+                                                              PassThrough,
+                                                              PassThrough,
+                                                              PassThrough>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_large_tensor_bf16_instances<2,
+                                                                 NHWGC,
+                                                                 GKYXC,
+                                                                 Empty_Tuple,
+                                                                 NHWGK,
+                                                                 ConvFwdDefault>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
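[Editor's note] For reference, add_device_operation_instances simply materializes every type in the instance tuple and appends it to the polymorphic vector. A simplified sketch of the mechanism (helper name hypothetical):

    #include <memory>
    #include <tuple>
    #include <vector>

    // Hypothetical simplified equivalent of ck's add_device_operation_instances:
    // default-construct every operation type in the tuple and hand ownership
    // to the instance vector of base-class pointers.
    template <typename BaseOp, typename... Ops>
    void add_instances_sketch(std::vector<std::unique_ptr<BaseOp>>& instances, std::tuple<Ops...>)
    {
        (instances.push_back(std::make_unique<Ops>()), ...); // C++17 fold over the tuple's types
    }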
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+void add_device_grouped_conv2d_fwd_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+                                                              NHWGC,
+                                                              GKYXC,
+                                                              Empty_Tuple,
+                                                              NHWGK,
+                                                              F16,
+                                                              F16,
+                                                              Empty_Tuple,
+                                                              F16,
+                                                              PassThrough,
+                                                              PassThrough,
+                                                              PassThrough>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_large_tensor_f16_instances<2,
+                                                                NHWGC,
+                                                                GKYXC,
+                                                                Empty_Tuple,
+                                                                NHWGK,
+                                                                ConvFwdDefault>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt
index e6fab095fb..165fd1ae43 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt
@@ -79,6 +79,8 @@ set(GROUPED_CONV3D_FWD
     wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part2.cpp
     wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part3.cpp
     wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part4.cpp
+    wmma/large_tensor/device_grouped_conv3d_fwd_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
+    wmma/large_tensor/device_grouped_conv3d_fwd_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
 )
 # Add generated files for sharded instantiations.
 include(ShardInstantiation)
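[Editor's note] The 3D cases are where the large-tensor path matters most, since NDHWGC element counts cross the 32-bit boundary quickly. A small, self-contained illustration (shape values chosen only for the arithmetic):

    #include <cstdint>
    #include <iostream>

    int main()
    {
        // G=1, N=16, C=64 with a 128^3 volume reaches exactly 2^31 elements,
        // one past INT32_MAX, so offsets must be computed as long_index_t (64-bit).
        const std::int64_t G = 1, N = 16, C = 64, D = 128, H = 128, W = 128;
        const std::int64_t elems = G * N * C * D * H * W;
        std::cout << elems << " > " << INT32_MAX << '\n'; // 2147483648 > 2147483647
        return 0;
    }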
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/large_tensor/device_grouped_conv3d_fwd_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/large_tensor/device_grouped_conv3d_fwd_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
new file mode 100644
index 0000000000..c20a45cade
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/large_tensor/device_grouped_conv3d_fwd_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
@@ -0,0 +1,38 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
+                                                              NDHWGC,
+                                                              GKZYXC,
+                                                              Empty_Tuple,
+                                                              NDHWGK,
+                                                              BF16,
+                                                              BF16,
+                                                              Empty_Tuple,
+                                                              BF16,
+                                                              PassThrough,
+                                                              PassThrough,
+                                                              PassThrough>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_large_tensor_bf16_instances<3,
+                                                                 NDHWGC,
+                                                                 GKZYXC,
+                                                                 Empty_Tuple,
+                                                                 NDHWGK,
+                                                                 ConvFwdDefault>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/large_tensor/device_grouped_conv3d_fwd_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/large_tensor/device_grouped_conv3d_fwd_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
new file mode 100644
index 0000000000..b3f00c7256
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/large_tensor/device_grouped_conv3d_fwd_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
@@ -0,0 +1,38 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
+                                                              NDHWGC,
+                                                              GKZYXC,
+                                                              Empty_Tuple,
+                                                              NDHWGK,
+                                                              F16,
+                                                              F16,
+                                                              Empty_Tuple,
+                                                              F16,
+                                                              PassThrough,
+                                                              PassThrough,
+                                                              PassThrough>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_large_tensor_f16_instances<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                Empty_Tuple,
+                                                                NDHWGK,
+                                                                ConvFwdDefault>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
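[Editor's note] The tests wired up below exercise these instances through the common CK pattern: build an argument, filter by IsSupportedArgument, then run. A hedged sketch of that loop, assuming op_ptrs and the argument were produced by the factory and MakeArgumentPointer calls shown earlier:

    // Sketch only; assumes the CK device-op headers are included so that
    // BaseArgument, BaseInvoker, and StreamConfig are visible.
    template <typename OpPtrs>
    void run_supported(OpPtrs& op_ptrs,
                       const ck::tensor_operation::device::BaseArgument* arg_ptr)
    {
        for(auto& op_ptr : op_ptrs)
        {
            if(op_ptr->IsSupportedArgument(arg_ptr)) // skip incompatible tile configs
            {
                auto invoker_ptr = op_ptr->MakeInvokerPointer();
                invoker_ptr->Run(arg_ptr, StreamConfig{}); // StreamConfig{} = default stream
            }
        }
    }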
diff --git a/test/grouped_convnd_fwd/CMakeLists.txt b/test/grouped_convnd_fwd/CMakeLists.txt
index a319857a5b..c7f4f66f58 100644
--- a/test/grouped_convnd_fwd/CMakeLists.txt
+++ b/test/grouped_convnd_fwd/CMakeLists.txt
@@ -8,13 +8,13 @@
 if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
     add_gtest_executable(test_grouped_convnd_fwd_scaleadd_ab test_grouped_convnd_fwd_scaleadd_ab.cpp)
     target_link_libraries(test_grouped_convnd_fwd_scaleadd_ab PRIVATE utility device_grouped_conv3d_fwd_scaleadd_ab_instance)
-    add_executable(test_grouped_convnd_fwd_large_cases_xdl test_grouped_convnd_fwd_large_cases_xdl.cpp)
-    target_compile_options(test_grouped_convnd_fwd_large_cases_xdl PRIVATE -Wno-global-constructors -Wno-undef)
-    target_link_libraries(test_grouped_convnd_fwd_large_cases_xdl PRIVATE gtest_main getopt::getopt utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance)
-
     add_executable(test_grouped_convnd_fwd_dataset_xdl test_grouped_convnd_fwd_dataset_xdl.cpp)
     target_compile_options(test_grouped_convnd_fwd_dataset_xdl PRIVATE -Wno-global-constructors -Wno-undef)
     target_link_libraries(test_grouped_convnd_fwd_dataset_xdl PRIVATE gtest_main getopt::getopt utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance)
+
+    add_executable(test_grouped_convnd_fwd_large_cases test_grouped_convnd_fwd_large_cases.cpp)
+    target_compile_options(test_grouped_convnd_fwd_large_cases PRIVATE -Wno-global-constructors -Wno-undef)
+    target_link_libraries(test_grouped_convnd_fwd_large_cases PRIVATE gtest_main getopt::getopt utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance)
 endif()
 
 add_gtest_executable(test_grouped_convnd_fwd_multi_ab_interface test_grouped_convnd_fwd_multi_ab_interface.cpp)
diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases_xdl.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases.cpp
similarity index 100%
rename from test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases_xdl.cpp
rename to test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases.cpp
index c549945d82..c51918e98f 100644
--- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases_xdl.cpp
+++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases.cpp
@@ -2,8 +2,8 @@
 // SPDX-License-Identifier: MIT
 
 #include
-#include
 #include
+#include
 #include
 #include