diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index f89d4c6c53..5c8b18d554 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -24,6 +24,7 @@ #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_descriptor_utils.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/host_utility/device_prop.hpp" @@ -716,7 +717,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle k_batch_ = split_k; } - // Step 1: Create initial descriptors with hack=false to check compactness + // Create initial descriptors with hack=false to check compactness const auto descs_initial = conv_to_gemm_transformer_v2 .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( @@ -738,17 +739,6 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle false, // hack=false for initial check true); // use_full_batch_kindex - // Step 2: Check if descriptors are compact (element_space == product of dimensions) - const auto a_dims_product = static_cast(descs_initial[I0].GetLength(I0)) * - static_cast(descs_initial[I0].GetLength(I1)) * - static_cast(descs_initial[I0].GetLength(I2)); - const auto b_dims_product = static_cast(descs_initial[I1].GetLength(I0)) * - static_cast(descs_initial[I1].GetLength(I1)) * - static_cast(descs_initial[I1].GetLength(I2)); - - const bool is_a_compact = (descs_initial[I0].GetElementSpaceSize() == a_dims_product); - const bool is_b_compact = (descs_initial[I1].GetElementSpaceSize() == b_dims_product); - ce_elementwise_grid_desc_m_n_ = conv_to_gemm_transformer_v1 .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( @@ -767,35 +757,16 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle input_right_pads, k_batch_)[I2]; - const index_t output_spatial_acum = ck::accumulate_n( - output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + std::tie(split_k_offset_a_hack_, split_k_offset_b_hack_) = + SplitKHackEligibility::Check( + descs_initial[I0], + descs_initial[I1], + k_batch_, + Conv_N_, + output_spatial_lengths_, + KPerBlock); - const bool is_k_not_paded = - (Conv_N_ * output_spatial_acum) % (KPerBlock * k_batch_) == 0; - - const bool can_divide_n_spatial_by_k_batch = - (Conv_N_ * output_spatial_acum) % k_batch_ == 0; - - const bool can_divide_n_by_k_batch = Conv_N_ % k_batch_ == 0; - - const bool is_correct_layout = - is_NSpatialGC_GKSpatial_NSpatialGK(); - - const bool is_a_stride_divisible = - descs_initial[I0].GetElementSpaceSize() % k_batch_ == 0; - - const bool is_b_stride_divisible = - descs_initial[I1].GetElementSpaceSize() % k_batch_ == 0; - - // Step 3: Determine if hack can be enabled (only for compact layouts) - split_k_offset_a_hack_ = k_batch_ > 1 && can_divide_n_spatial_by_k_batch && - is_k_not_paded && is_correct_layout && is_a_stride_divisible && - is_a_compact; - - split_k_offset_b_hack_ = k_batch_ > 1 && can_divide_n_by_k_batch && is_k_not_paded && - is_correct_layout && is_b_stride_divisible && is_b_compact; - - // Step 4: Create final descriptors with correct hack flags + // Create final descriptors with correct hack flags const auto descs = conv_to_gemm_transformer_v2 .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index ae3d15ccf8..996df9aba5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -21,6 +21,7 @@ #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_descriptor_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -612,48 +613,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle false, // split_k_offset_a_hack (temporary) false); // split_k_offset_b_hack (temporary) - const index_t output_spatial_acum = ck::accumulate_n( - output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); - - const bool is_k_not_paded = - (Conv_N_ * output_spatial_acum) % (K0PerBlock * K1 * k_batch_) == 0; - - const bool can_divide_n_spatial_by_k_batch = - (Conv_N_ * output_spatial_acum) % k_batch_ == 0; - - const bool can_divide_n_by_k_batch = Conv_N_ % k_batch_ == 0; - - const bool is_correct_layout = - is_NSpatialGC_GKSpatial_NSpatialGK(); - - const bool is_a_stride_divisible = - descs_initial[I0].GetElementSpaceSize() % k_batch_ == 0; - - const bool is_b_stride_divisible = - descs_initial[I1].GetElementSpaceSize() % k_batch_ == 0; - - // Check if descriptor has compact layout (product of dimensions equals element space) - // Non-compact layouts have complex transform pipelines that don't support the hack - const auto a_dims_product = static_cast(descs_initial[I0].GetLength(I0)) * - static_cast(descs_initial[I0].GetLength(I1)) * - static_cast(descs_initial[I0].GetLength(I2)) * - static_cast(descs_initial[I0].GetLength(I3)); - const auto b_dims_product = static_cast(descs_initial[I1].GetLength(I0)) * - static_cast(descs_initial[I1].GetLength(I1)) * - static_cast(descs_initial[I1].GetLength(I2)) * - static_cast(descs_initial[I1].GetLength(I3)); - - const bool is_a_compact = (descs_initial[I0].GetElementSpaceSize() == a_dims_product); - const bool is_b_compact = (descs_initial[I1].GetElementSpaceSize() == b_dims_product); - - // Determine if we can safely use the split-k offset hack - // Only enable for compact layouts where element_space_size == product of dimensions - split_k_offset_a_hack_ = k_batch_ > 1 && can_divide_n_spatial_by_k_batch && - is_k_not_paded && is_correct_layout && is_a_stride_divisible && - is_a_compact; - - split_k_offset_b_hack_ = k_batch_ > 1 && can_divide_n_by_k_batch && is_k_not_paded && - is_correct_layout && is_b_stride_divisible && is_b_compact; + std::tie(split_k_offset_a_hack_, split_k_offset_b_hack_) = + SplitKHackEligibility::Check( + descs_initial[I0], + descs_initial[I1], + k_batch_, + Conv_N_, + output_spatial_lengths_, + K0PerBlock * K1); // Now create descriptors with the correct hack flags const auto descs = diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index ed602295dc..904a3588d6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -22,6 +22,7 @@ #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_descriptor_utils.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" @@ -607,47 +608,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 false, // split_k_offset_b_hack (temporary) true); // use_full_batch_kindex=true for V1-compatible descriptors - const index_t output_spatial_acum = ck::accumulate_n( - output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); - - const bool is_k_not_paded = - (Conv_N_ * output_spatial_acum) % (K0PerBlock * k_batch_) == 0; - - const bool can_divide_n_spatial_by_k_batch = - (Conv_N_ * output_spatial_acum) % k_batch_ == 0; - - const bool can_divide_n_by_k_batch = Conv_N_ % k_batch_ == 0; - - const bool is_correct_layout = - is_NSpatialGC_GKSpatial_NSpatialGK(); - - const bool is_a_stride_divisible = - descs_initial[I0].GetElementSpaceSize() % k_batch_ == 0; - - const bool is_b_stride_divisible = - descs_initial[I1].GetElementSpaceSize() % k_batch_ == 0; - - // Check if descriptor has compact layout (product of dimensions equals element space) - // Non-compact layouts have complex transform pipelines that don't support the hack - // Note: CShuffleV3 descriptors are 3D [K0, M, K1], not 4D like CShuffle - const auto a_dims_product = static_cast(descs_initial[I0].GetLength(I0)) * - static_cast(descs_initial[I0].GetLength(I1)) * - static_cast(descs_initial[I0].GetLength(I2)); - const auto b_dims_product = static_cast(descs_initial[I1].GetLength(I0)) * - static_cast(descs_initial[I1].GetLength(I1)) * - static_cast(descs_initial[I1].GetLength(I2)); - - const bool is_a_compact = (descs_initial[I0].GetElementSpaceSize() == a_dims_product); - const bool is_b_compact = (descs_initial[I1].GetElementSpaceSize() == b_dims_product); - - // Determine if we can safely use the split-k offset hack - // Only enable for compact layouts where element_space_size == product of dimensions - split_k_offset_a_hack_ = k_batch_ > 1 && can_divide_n_spatial_by_k_batch && - is_k_not_paded && is_correct_layout && is_a_stride_divisible && - is_a_compact; - - split_k_offset_b_hack_ = k_batch_ > 1 && can_divide_n_by_k_batch && is_k_not_paded && - is_correct_layout && is_b_stride_divisible && is_b_compact; + std::tie(split_k_offset_a_hack_, split_k_offset_b_hack_) = + SplitKHackEligibility::Check( + descs_initial[I0], + descs_initial[I1], + k_batch_, + Conv_N_, + output_spatial_lengths_, + K0PerBlock); // Now create descriptors with the correct hack flags const auto descs = diff --git a/include/ck/tensor_operation/gpu/device/impl/split_k_descriptor_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/split_k_descriptor_utils.hpp new file mode 100644 index 0000000000..cafe299c1a --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/split_k_descriptor_utils.hpp @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Check if a tensor descriptor has compact layout +// Compact means: GetElementSpaceSize() == product of all dimension lengths +// Non-compact descriptors have complex transform pipelines that may not support split-k hack +template +bool IsDescriptorCompact(const Descriptor& desc) +{ + // Calculate product of all dimensions + long_index_t dims_product = 1; + constexpr index_t num_dims = Descriptor::GetNumOfDimension(); + + // Use template recursion to multiply all dimension lengths + static_for<0, num_dims, 1>{}( + [&](auto i) { dims_product *= static_cast(desc.GetLength(i)); }); + + return desc.GetElementSpaceSize() == dims_product; +} + +// Determine split-k hack eligibility for descriptor pair +// This checks all the conditions required for safely using the split-k offset hack +template +struct SplitKHackEligibility +{ + template + static auto + Check(const ADescriptor& a_desc, + const BDescriptor& b_desc, + index_t k_batch, + index_t Conv_N, + const std::array& output_spatial_lengths, + index_t k_block_size) // K0PerBlock*K1 for v1, K0PerBlock for v3, KPerBlock for two-stage + { + // Only enable hack if k_batch > 1 + if(k_batch <= 1) + { + return std::make_pair(false, false); + } + + // Calculate output spatial product + const index_t output_spatial_acum = std::accumulate(output_spatial_lengths.begin(), + output_spatial_lengths.end(), + index_t{1}, + std::multiplies()); + + // Check various divisibility and layout requirements + const bool is_k_not_paded = (Conv_N * output_spatial_acum) % (k_block_size * k_batch) == 0; + + const bool can_divide_n_spatial_by_k_batch = (Conv_N * output_spatial_acum) % k_batch == 0; + + const bool can_divide_n_by_k_batch = Conv_N % k_batch == 0; + + const bool is_correct_layout = + is_NSpatialGC_GKSpatial_NSpatialGK(); + + const bool is_a_stride_divisible = a_desc.GetElementSpaceSize() % k_batch == 0; + + const bool is_b_stride_divisible = b_desc.GetElementSpaceSize() % k_batch == 0; + + // Check descriptor compactness + const bool is_a_compact = IsDescriptorCompact(a_desc); + const bool is_b_compact = IsDescriptorCompact(b_desc); + + // Determine hack flags based on all conditions + const bool split_k_offset_a_hack = can_divide_n_spatial_by_k_batch && is_k_not_paded && + is_correct_layout && is_a_stride_divisible && + is_a_compact; + + const bool split_k_offset_b_hack = can_divide_n_by_k_batch && is_k_not_paded && + is_correct_layout && is_b_stride_divisible && + is_b_compact; + + return std::make_pair(split_k_offset_a_hack, split_k_offset_b_hack); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck