diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 4a528831ad..b011356fca 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -24,7 +24,7 @@ #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" -#include "ck/tensor_operation/gpu/device/impl/split_k_descriptor_utils.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/host_utility/device_prop.hpp" @@ -35,225 +35,6 @@ namespace ck { namespace tensor_operation { namespace device { -// Helper function to dispatch split-K hack for standard kernel (single LDS) -template -__device__ void DispatchSplitKHack(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - void* p_shared, - const typename GridwiseGemm::Argument& karg, - const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock, - index_t k_id, - index_t k_batch, - bool split_k_offset_a_hack, - bool split_k_offset_b_hack) -{ - if(split_k_offset_a_hack && split_k_offset_b_hack) - { - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_id, - k_batch); - } - else if(split_k_offset_a_hack) - { - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_id, - k_batch); - } - else if(split_k_offset_b_hack) - { - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_id, - k_batch); - } - else - { - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_id, - k_batch); - } -} - -// Helper function to dispatch split-K hack for 2lds kernel -template -__device__ void DispatchSplitKHack_2Lds(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - void* p_shared_0, - void* p_shared_1, - const typename GridwiseGemm::Argument& karg, - const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock, - index_t k_id, - index_t k_batch, - bool split_k_offset_a_hack, - bool split_k_offset_b_hack) -{ - if(split_k_offset_a_hack && split_k_offset_b_hack) - { - GridwiseGemm::template Run_2Lds(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_0, - p_shared_1, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_id, - k_batch); - } - else if(split_k_offset_a_hack) - { - GridwiseGemm::template Run_2Lds(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_0, - p_shared_1, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_id, - k_batch); - } - else if(split_k_offset_b_hack) - { - GridwiseGemm::template Run_2Lds(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_0, - p_shared_1, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_id, - k_batch); - } - else - { - GridwiseGemm::template Run_2Lds(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_0, - p_shared_1, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_id, - k_batch); - } -} - template -__device__ void DispatchSplitKHack(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - void* p_shared, - const typename GridwiseGemm::Argument& karg, - const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock, - index_t k_id, - index_t k_batch, - bool split_k_offset_a_hack, - bool split_k_offset_b_hack) -{ - if(split_k_offset_a_hack && split_k_offset_b_hack) - { - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_id, - k_batch); - } - else if(split_k_offset_a_hack) - { - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_id, - k_batch); - } - else if(split_k_offset_b_hack) - { - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_id, - k_batch); - } - else - { - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_id, - k_batch); - } -} - -// Helper function to dispatch split-K hack for 2lds kernel -template -__device__ void DispatchSplitKHack_2Lds(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - void* p_shared_0, - void* p_shared_1, - const typename GridwiseGemm::Argument& karg, - const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock, - index_t k_id, - index_t k_batch, - bool split_k_offset_a_hack, - bool split_k_offset_b_hack) -{ - if(split_k_offset_a_hack && split_k_offset_b_hack) - { - GridwiseGemm::template Run_2Lds(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_0, - p_shared_1, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_id, - k_batch); - } - else if(split_k_offset_a_hack) - { - GridwiseGemm::template Run_2Lds(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_0, - p_shared_1, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_id, - k_batch); - } - else if(split_k_offset_b_hack) - { - GridwiseGemm::template Run_2Lds(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_0, - p_shared_1, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_id, - k_batch); - } - else - { - GridwiseGemm::template Run_2Lds(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_0, - p_shared_1, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_id, - k_batch); - } -} - template -#include "ck/utility/common_header.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -// Check if a tensor descriptor has compact layout -// Compact means: GetElementSpaceSize() == product of all dimension lengths -// Non-compact descriptors have complex transform pipelines that may not support split-k hack -template -bool IsDescriptorCompact(const Descriptor& desc) -{ - // Calculate product of all dimensions - long_index_t dims_product = 1; - constexpr index_t num_dims = Descriptor::GetNumOfDimension(); - - // Use template recursion to multiply all dimension lengths - static_for<0, num_dims, 1>{}( - [&](auto i) { dims_product *= static_cast(desc.GetLength(i)); }); - - return desc.GetElementSpaceSize() == dims_product; -} - -// Determine split-k hack eligibility for descriptor pair -// This checks all the conditions required for safely using the split-k offset hack -template -struct SplitKHackEligibility -{ - template - static auto - Check(const ADescriptor& a_desc, - const BDescriptor& b_desc, - index_t k_batch, - index_t Conv_N, - const std::array& output_spatial_lengths, - index_t k_block_size) // K0PerBlock*K1 for v1, K0PerBlock for v3, KPerBlock for two-stage - { - // Only enable hack if k_batch > 1 - if(k_batch <= 1) - { - return std::make_pair(false, false); - } - - // Calculate output spatial product - const index_t output_spatial_acum = std::accumulate(output_spatial_lengths.begin(), - output_spatial_lengths.end(), - index_t{1}, - std::multiplies()); - - // Check various divisibility and layout requirements - const bool is_k_not_paded = (Conv_N * output_spatial_acum) % (k_block_size * k_batch) == 0; - - const bool can_divide_n_spatial_by_k_batch = (Conv_N * output_spatial_acum) % k_batch == 0; - - const bool can_divide_n_by_k_batch = Conv_N % k_batch == 0; - - const bool is_correct_layout = - is_NSpatialGC_GKSpatial_NSpatialGK(); - - const bool is_a_stride_divisible = a_desc.GetElementSpaceSize() % k_batch == 0; - - const bool is_b_stride_divisible = b_desc.GetElementSpaceSize() % k_batch == 0; - - // Check descriptor compactness - const bool is_a_compact = IsDescriptorCompact(a_desc); - const bool is_b_compact = IsDescriptorCompact(b_desc); - - // Determine hack flags based on all conditions - const bool split_k_offset_a_hack = can_divide_n_spatial_by_k_batch && is_k_not_paded && - is_correct_layout && is_a_stride_divisible && - is_a_compact; - - const bool split_k_offset_b_hack = can_divide_n_by_k_batch && is_k_not_paded && - is_correct_layout && is_b_stride_divisible && - is_b_compact; - - return std::make_pair(split_k_offset_a_hack, split_k_offset_b_hack); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp new file mode 100644 index 0000000000..bdce7fe6a4 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp @@ -0,0 +1,312 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Check if a tensor descriptor has compact layout +// Compact means: GetElementSpaceSize() == product of all dimension lengths +// Non-compact descriptors have complex transform pipelines that may not support split-k hack +template +bool IsDescriptorCompact(const Descriptor& desc) +{ + // Calculate product of all dimensions + long_index_t dims_product = 1; + constexpr index_t num_dims = Descriptor::GetNumOfDimension(); + + // Use template recursion to multiply all dimension lengths + static_for<0, num_dims, 1>{}( + [&](auto i) { dims_product *= static_cast(desc.GetLength(i)); }); + + return desc.GetElementSpaceSize() == dims_product; +} + +// Determine split-k hack eligibility for descriptor pair +// This checks all the conditions required for safely using the split-k offset hack +template +struct SplitKHackEligibility +{ + template + static auto + Check(const ADescriptor& a_desc, + const BDescriptor& b_desc, + index_t k_batch, + index_t Conv_N, + const std::array& output_spatial_lengths, + index_t k_block_size) // K0PerBlock*K1 for v1, K0PerBlock for v3, KPerBlock for two-stage + { + // Only enable hack if k_batch > 1 + if(k_batch <= 1) + { + return std::make_pair(false, false); + } + + // Calculate output spatial product + const index_t output_spatial_acum = std::accumulate(output_spatial_lengths.begin(), + output_spatial_lengths.end(), + index_t{1}, + std::multiplies()); + + // Check various divisibility and layout requirements + const bool is_k_not_paded = (Conv_N * output_spatial_acum) % (k_block_size * k_batch) == 0; + + const bool can_divide_n_spatial_by_k_batch = (Conv_N * output_spatial_acum) % k_batch == 0; + + const bool can_divide_n_by_k_batch = Conv_N % k_batch == 0; + + const bool is_correct_layout = + is_NSpatialGC_GKSpatial_NSpatialGK(); + + const bool is_a_stride_divisible = a_desc.GetElementSpaceSize() % k_batch == 0; + + const bool is_b_stride_divisible = b_desc.GetElementSpaceSize() % k_batch == 0; + + // Check descriptor compactness + const bool is_a_compact = IsDescriptorCompact(a_desc); + const bool is_b_compact = IsDescriptorCompact(b_desc); + + // Determine hack flags based on all conditions + const bool split_k_offset_a_hack = can_divide_n_spatial_by_k_batch && is_k_not_paded && + is_correct_layout && is_a_stride_divisible && + is_a_compact; + + const bool split_k_offset_b_hack = can_divide_n_by_k_batch && is_k_not_paded && + is_correct_layout && is_b_stride_divisible && + is_b_compact; + + return std::make_pair(split_k_offset_a_hack, split_k_offset_b_hack); + } +}; + +// Helper function to dispatch split-K hack for standard kernel (single LDS) +// Reduces code duplication in device layer implementations +template +__device__ void DispatchSplitKHack(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const typename GridwiseGemm::Argument& karg, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + index_t k_id, + index_t k_batch, + bool split_k_offset_a_hack, + bool split_k_offset_b_hack) +{ + if(split_k_offset_a_hack && split_k_offset_b_hack) + { + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } + else if(split_k_offset_a_hack) + { + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } + else if(split_k_offset_b_hack) + { + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } + else + { + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } +} + +// Helper function to dispatch split-K hack for 2lds kernel +// Reduces code duplication in device layer implementations +template +__device__ void DispatchSplitKHack_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const typename GridwiseGemm::Argument& karg, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + index_t k_id, + index_t k_batch, + bool split_k_offset_a_hack, + bool split_k_offset_b_hack) +{ + if(split_k_offset_a_hack && split_k_offset_b_hack) + { + GridwiseGemm::template Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } + else if(split_k_offset_a_hack) + { + GridwiseGemm::template Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } + else if(split_k_offset_b_hack) + { + GridwiseGemm::template Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } + else + { + GridwiseGemm::template Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } +} + +} // namespace device +} // namespace tensor_operation +} // namespace ck