diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 3e8a0fd3fb..211496b3ff 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -24,6 +24,7 @@ #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/host_utility/device_prop.hpp" @@ -60,13 +61,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - [[maybe_unused]] const index_t num_k_per_block) + [[maybe_unused]] const index_t num_k_per_block, + const long_index_t split_k_stride_a, + const long_index_t split_k_stride_b, + bool split_k_offset_hack) { #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); - const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + + const long_index_t split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0; + const long_index_t split_k_offset_b = split_k_offset_hack ? 
k_idx * split_k_stride_b : 0; const long_index_t a_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); @@ -77,23 +84,29 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); + DispatchSplitKHack(karg.p_a_grid + a_batch_offset + split_k_offset_a, + karg.p_b_grid + b_batch_offset + split_k_offset_b, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx * num_k_per_block, + gridDim.y, + split_k_offset_hack); } #else ignore = karg; + ignore = split_k_stride_a; + ignore = split_k_stride_b; + ignore = split_k_offset_hack; #endif // end of if (defined(__gfx9__)) } @@ -118,14 +131,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - [[maybe_unused]] const index_t num_k_per_block) + [[maybe_unused]] const index_t num_k_per_block, + const long_index_t split_k_stride_a, + const long_index_t split_k_stride_b, + bool split_k_offset_hack) { #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { // offset base pointer for each work-group const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); - const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + + const long_index_t split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0; + const long_index_t split_k_offset_b = split_k_offset_hack ? 
k_idx * split_k_stride_b : 0; const long_index_t a_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); @@ -139,24 +158,30 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, - p_shared_0, - p_shared_1, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); + DispatchSplitKHack_2Lds(karg.p_a_grid + a_batch_offset + split_k_offset_a, + karg.p_b_grid + b_batch_offset + split_k_offset_b, + karg.p_c_grid + e_batch_offset, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx * num_k_per_block, + gridDim.y, + split_k_offset_hack); } #else ignore = karg; + ignore = split_k_offset_hack; + ignore = split_k_stride_a; + ignore = split_k_stride_b; #endif // end of if (defined(__gfx9__)) } @@ -693,7 +718,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle k_batch_ = split_k; } - const auto descs = + // Create initial descriptors with hack=false to check compactness + const auto descs_initial = conv_to_gemm_transformer_v2 .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( Conv_N_, @@ -709,11 +735,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle conv_filter_dilations, input_left_pads, input_right_pads, - k_batch_); - - a_grid_desc_k0_m_k1_ = descs[I0]; - b_grid_desc_k0_n_k1_ = descs[I1]; - ce_grid_desc_m_n_ = descs[I2]; + k_batch_, + false, // hack=false for initial check + true); // use_full_batch_kindex ce_elementwise_grid_desc_m_n_ = conv_to_gemm_transformer_v1 @@ -733,6 +757,67 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle input_right_pads, k_batch_)[I2]; + split_k_offset_hack_ = + SplitKHackEligibility::Check( + descs_initial[I0], + descs_initial[I1], + k_batch_, + Conv_N_, + output_spatial_lengths_, + KPerBlock); + + // Create final descriptors with correct hack flag + const auto descs = + conv_to_gemm_transformer_v2 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides_transposed, + e_g_k_c_xs_strides_transposed, + a_g_n_k_wos_strides_transposed, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_, + split_k_offset_hack_, // Use determined hack flag + true); // use_full_batch_kindex + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + ce_grid_desc_m_n_ = descs[I2]; + + // Step 5: Calculate stride using CalculateOffset on FINAL descriptors + if(split_k_offset_hack_) + { + const index_t k0_per_batch = a_grid_desc_k0_m_k1_.GetLength(I0) / k_batch_; + const auto idx_start = make_multi_index(0, 0, 0); + const auto idx_next = make_multi_index(k0_per_batch, 0, 0); + split_k_stride_a_ = a_grid_desc_k0_m_k1_.CalculateOffset(idx_next) - + a_grid_desc_k0_m_k1_.CalculateOffset(idx_start); + } + else + { + split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize(); + } + + if(split_k_offset_hack_) + { + const index_t k0_per_batch = b_grid_desc_k0_n_k1_.GetLength(I0) / k_batch_; + const auto idx_start = make_multi_index(0, 0, 0); + const auto idx_next = 
make_multi_index(k0_per_batch, 0, 0); + split_k_stride_b_ = b_grid_desc_k0_n_k1_.CalculateOffset(idx_next) - + b_grid_desc_k0_n_k1_.CalculateOffset(idx_start); + } + else + { + split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize(); + } + const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1); const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1); @@ -869,6 +954,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle const std::array& input_left_pads_; const std::array& input_right_pads_; long_index_t c_space_size_bytes; + + bool split_k_offset_hack_; + long_index_t split_k_stride_a_, split_k_stride_b_; }; // Invoker @@ -971,7 +1059,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, - num_k_per_block); + num_k_per_block, + arg.split_k_stride_a_, + arg.split_k_stride_b_, + arg.split_k_offset_hack_); } else { @@ -987,7 +1078,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, - num_k_per_block); + num_k_per_block, + arg.split_k_stride_a_, + arg.split_k_stride_b_, + arg.split_k_offset_hack_); } }; @@ -1920,14 +2014,6 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle } } - constexpr long_index_t TwoGB = (long_index_t{1} << 31); - if(!(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && - arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && - arg.ce_grid_desc_m_n_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) - { - return false; - } - return true; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 42ad21dafe..976b6f1ef8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -21,6 +21,7 @@ #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -33,6 +34,74 @@ namespace ck { namespace tensor_operation { namespace device { +// Dispatch helper function for split-K hack - handles 2-way dispatch based on runtime flag +template +__device__ void DispatchBatchedGemmSplitKHack(const FloatA* p_a_grid, + const FloatB* p_b_grid, + FloatC* p_c_grid, + void* p_shared, + const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map, + const long_index_t split_k_stride_a, + const long_index_t split_k_stride_b, + bool split_k_offset_hack, + index_t k_batch) +{ + if(split_k_offset_hack) + { + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + 
a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map, + split_k_stride_a, + split_k_stride_b, + k_batch); + } + else + { + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map, + split_k_stride_a, + split_k_stride_b, + k_batch); + } +} + template (p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_c_grid + c_batch_offset, - p_shared, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_grid_desc_mblock_mperblock_nblock_nperblock, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + DispatchBatchedGemmSplitKHack( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map, + split_k_stride_a, + split_k_stride_b, + split_k_offset_hack, + k_batch); } #else ignore = p_a_grid; @@ -104,6 +193,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) ignore = batch_count; ignore = block_2_ctile_map; ignore = compute_ptr_offset_of_batch; + ignore = split_k_stride_a; + ignore = split_k_stride_b; + ignore = split_k_offset_hack; + ignore = k_batch; compute_ptr_offset_of_batch.GetAPtrOffset(0); compute_ptr_offset_of_batch.GetBPtrOffset(0); @@ -459,7 +552,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle remove_reference_t, remove_reference_t, ComputePtrOffsetOfStridedBatch<>, - false>, // Both true/false give the same occupancy. + false>, // HasMainKBlockLoop - both true/false give the same occupancy BlockSize, dynamic_smem_size)); return std::max(1, max_occupancy); @@ -576,6 +669,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle k_batch_ = split_k; } + // Create descriptors first (with hack flags temporarily set to false) + // so we can check if element space sizes are divisible by k_batch + const auto descs_initial = + conv_to_gemm_transformer + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides_transposed, + e_g_k_c_xs_strides_transposed, + a_g_n_k_wos_strides_transposed, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_, + false); // split_k_offset_b_hack (temporary) + + split_k_offset_hack_ = + SplitKHackEligibility::Check( + descs_initial[I0], + descs_initial[I1], + k_batch_, + Conv_N_, + output_spatial_lengths_, + K0PerBlock * K1); + + // Now create descriptors with the correct hack flag const auto descs = conv_to_gemm_transformer .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( @@ -592,12 +716,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle conv_filter_dilations, input_left_pads, input_right_pads, - k_batch_); + k_batch_, + split_k_offset_hack_); a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; c_grid_desc_m_n_ = descs[I2]; + // Calculate stride using CalculateOffset method for accurate stride + // This works correctly for any descriptor transform pipeline + split_k_stride_a_ = a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize(); + if(split_k_offset_hack_) + split_k_stride_a_ /= k_batch_; + + split_k_stride_b_ = b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize(); + if(split_k_offset_hack_) + split_k_stride_b_ /= 
k_batch_; + block_2_ctile_map_ = GridwiseGemm64::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); @@ -732,6 +867,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const std::array& input_left_pads_; const std::array& input_right_pads_; long_index_t c_space_size_bytes; + + bool split_k_offset_hack_; + long_index_t split_k_stride_a_, split_k_stride_b_; }; // Invoker @@ -878,7 +1016,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle arg.b_grid_desc_kbatch_k0_n_k1_, c_grid_desc_mblock_mperblock_nblock_nperblock, arg.block_2_ctile_map_, - arg.compute_ptr_offset_of_batch_); + arg.compute_ptr_offset_of_batch_, + arg.split_k_stride_a_, + arg.split_k_stride_b_, + arg.split_k_offset_hack_, + arg.k_batch_); }; if(has_main_k0_block_loop) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 9df78f55e5..2121be00d1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -22,6 +22,7 @@ #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" @@ -58,13 +59,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const index_t num_k_per_block) + const index_t num_k_per_block, + const long_index_t split_k_stride_a, + const long_index_t split_k_stride_b, + bool split_k_offset_hack) { #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + + const long_index_t split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0; + const long_index_t split_k_offset_b = split_k_offset_hack ? 
k_idx * split_k_stride_b : 0; const long_index_t a_batch_offset = amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); @@ -74,20 +81,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); + + DispatchSplitKHack(karg.p_a_grid + a_batch_offset + split_k_offset_a, + karg.p_b_grid + b_batch_offset + split_k_offset_b, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx * num_k_per_block, + gridDim.y, + split_k_offset_hack); } #else ignore = karg; @@ -96,6 +107,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = compute_ptr_offset_of_batch; ignore = num_k_per_block; + ignore = split_k_stride_a; + ignore = split_k_stride_b; + ignore = split_k_offset_hack; + #endif // end of if (defined(__gfx9__) } @@ -119,14 +134,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const index_t num_k_per_block) + const index_t num_k_per_block, + const long_index_t split_k_stride_a, + const long_index_t split_k_stride_b, + bool split_k_offset_hack) { #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { // offset base pointer for each work-group const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + + const long_index_t split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0; + const long_index_t split_k_offset_b = split_k_offset_hack ? 
k_idx * split_k_stride_b : 0; const long_index_t a_batch_offset = amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); @@ -140,21 +161,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, - p_shared_0, - p_shared_1, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); + DispatchSplitKHack_2Lds(karg.p_a_grid + a_batch_offset + split_k_offset_a, + karg.p_b_grid + b_batch_offset + split_k_offset_b, + karg.p_c_grid + e_batch_offset, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx * num_k_per_block, + gridDim.y, + split_k_offset_hack); } #else ignore = karg; @@ -163,6 +187,9 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = compute_ptr_offset_of_batch; ignore = num_k_per_block; + ignore = split_k_stride_a; + ignore = split_k_stride_b; + ignore = split_k_offset_hack; #endif // end of if (defined(__gfx9__) } @@ -490,8 +517,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 : p_a_grid_{p_out_grid}, p_b_grid_{p_in_grid}, p_c_grid_{p_wei_grid}, - a_grid_desc_kbatch_k0_m_k1_{}, - b_grid_desc_kbatch_k0_n_k1_{}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, c_grid_desc_m_n_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{}, compute_ptr_offset_of_batch_{}, @@ -560,6 +587,38 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 k_batch_ = split_k; } + // Create descriptors first (with hack flags temporarily set to false) + // so we can check if element space sizes match product of dimensions + const auto descs_initial = + conv_to_gemm_transformer + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides, + e_g_k_c_xs_strides, + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_, + false, // split_k_offset_b_hack (temporary) + true); // use_full_batch_kindex=true for V1-compatible descriptors + + split_k_offset_hack_ = + SplitKHackEligibility::Check( + descs_initial[I0], + descs_initial[I1], + k_batch_, + Conv_N_, + output_spatial_lengths_, + K0PerBlock); + + // Now create descriptors with the correct hack flag const auto descs = conv_to_gemm_transformer .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( @@ -576,11 +635,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 conv_filter_dilations, input_left_pads, input_right_pads, - k_batch_); + k_batch_, + split_k_offset_hack_, + true); // use_full_batch_kindex=true for V1-compatible descriptors - a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; - b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; - c_grid_desc_m_n_ = descs[I2]; + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + // Calculate stride using CalculateOffset method for accurate stride + // This works correctly for any descriptor transform pipeline + split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize(); + if(split_k_offset_hack_) + split_k_stride_a_ /= k_batch_; + + 
split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize(); + if(split_k_offset_hack_) + split_k_stride_b_ /= k_batch_; // A/B/C Batch Stride compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; @@ -591,8 +662,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 end(filter_spatial_lengths_), index_t{1}, std::multiplies<>{}); - const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1); c_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -604,8 +675,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 const ADataType* p_a_grid_; const BDataType* p_b_grid_; CDataType* p_c_grid_; - AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; @@ -631,6 +702,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 const std::array& input_left_pads_; const std::array& input_right_pads_; long_index_t c_space_size_bytes; + + bool split_k_offset_hack_; + long_index_t split_k_stride_a_, split_k_stride_b_; }; // Invoker @@ -640,17 +714,15 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 void ShowInfo(const Argument& arg) { - std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I3) << "}" << std::endl; - std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I3) << "}" << std::endl; std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; @@ -659,10 +731,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 template float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); - const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * - arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); + const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); const ADataType* p_a_grid = arg.p_a_grid_; const BDataType* p_b_grid 
= arg.p_b_grid_; @@ -680,7 +752,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); const auto num_k_per_block = - arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; + arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; const auto clear_workspace = [&]() { if(arg.k_batch_ > 1) @@ -716,11 +788,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 dim3(BlockSize), 0, gemm_arg_, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, - num_k_per_block); + num_k_per_block, + arg.split_k_stride_a_, + arg.split_k_stride_b_, + arg.split_k_offset_hack_); } else { @@ -732,11 +807,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 dim3(BlockSize), 0, gemm_arg, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, - num_k_per_block); + num_k_per_block, + arg.split_k_stride_a_, + arg.split_k_stride_b_, + arg.split_k_offset_hack_); } }; @@ -749,7 +827,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) { - if(gemm_arg.KBatch > 1) + if(arg.k_batch_ > 1) { const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< GridwiseGemm, @@ -781,7 +859,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 // Tail number could be One to Seven else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) { - if(gemm_arg.KBatch > 1) + if(arg.k_batch_ > 1) { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) { @@ -1090,7 +1168,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 // Tail number could be Odd or Even else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { - if(gemm_arg.KBatch > 1) + if(arg.k_batch_ > 1) { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { @@ -1159,7 +1237,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 } else { - if(gemm_arg.KBatch > 1) + if(arg.k_batch_ > 1) { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { @@ -1232,7 +1310,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 // Tail number always 1 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { - if(gemm_arg.KBatch > 1) + if(arg.k_batch_ > 1) { const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< GridwiseGemm, @@ -1289,10 +1367,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 } #endif - const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); - const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * - arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); + const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); if constexpr(is_same_v || is_same_v) { @@ -1423,9 +1501,17 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 } constexpr long_index_t TwoGB = (long_index_t{1} << 31); - if(!(arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB 
&& - arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && - arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) <= TwoGB)) + const bool a_small_enough = arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() / + (arg.split_k_offset_hack_ ? arg.k_batch_ : 1) * + sizeof(ADataType) <= + TwoGB; + const bool b_small_enough = arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() / + (arg.split_k_offset_hack_ ? arg.k_batch_ : 1) * + sizeof(BDataType) <= + TwoGB; + const bool c_small_enough = + arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) <= TwoGB; + if(!(a_small_enough && b_small_enough && c_small_enough)) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp new file mode 100644 index 0000000000..6fe4257dbb --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp @@ -0,0 +1,222 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Check if a tensor descriptor has compact layout +// Compact means: GetElementSpaceSize() == product of all dimension lengths +// Non-compact descriptors have complex transform pipelines that may not support split-k hack +template +bool IsDescriptorCompact(const Descriptor& desc) +{ + // Calculate product of all dimensions + long_index_t dims_product = 1; + constexpr index_t num_dims = Descriptor::GetNumOfDimension(); + + // Use template recursion to multiply all dimension lengths + static_for<0, num_dims, 1>{}( + [&](auto i) { dims_product *= static_cast(desc.GetLength(i)); }); + + return desc.GetElementSpaceSize() == dims_product; +} + +// Determine split-k hack eligibility for descriptor pair +// This checks all the conditions required for safely using the split-k offset hack +template +struct SplitKHackEligibility +{ + template + static bool + Check(const ADescriptor& a_desc, + const BDescriptor& b_desc, + index_t k_batch, + index_t Conv_N, + const std::array& output_spatial_lengths, + index_t k_block_size) // K0PerBlock*K1 for v1, K0PerBlock for v3, KPerBlock for two-stage + { + // Only enable hack if k_batch > 1 + if(k_batch <= 1) + { + return false; + } + + // Calculate output spatial product + const index_t output_spatial_acum = std::accumulate(output_spatial_lengths.begin(), + output_spatial_lengths.end(), + index_t{1}, + std::multiplies()); + + // Check various divisibility and layout requirements + const bool is_k_not_paded = (Conv_N * output_spatial_acum) % (k_block_size * k_batch) == 0; + + const bool can_divide_n_spatial_by_k_batch = (Conv_N * output_spatial_acum) % k_batch == 0; + + const bool can_divide_n_by_k_batch = Conv_N % k_batch == 0; + + const bool is_correct_layout = + is_NSpatialGC_GKSpatial_NSpatialGK(); + + const bool is_a_stride_divisible = a_desc.GetElementSpaceSize() % k_batch == 0; + + const bool is_b_stride_divisible = b_desc.GetElementSpaceSize() % k_batch == 0; + + // Check descriptor compactness + const bool is_a_compact = IsDescriptorCompact(a_desc); + const bool is_b_compact = IsDescriptorCompact(b_desc); + + // Require BOTH A and B to be eligible for the hack to avoid KBatch dimension 
mismatch + // The gridwise kernel's CheckValidity requires A.KBatch == B.KBatch, so we must + // apply the hack uniformly to both tensors to maintain kernel applicability + const bool eligible = can_divide_n_spatial_by_k_batch && can_divide_n_by_k_batch && + is_k_not_paded && is_correct_layout && is_a_stride_divisible && + is_b_stride_divisible && is_a_compact && is_b_compact; + + return eligible; + } +}; + +// Helper function to dispatch split-K hack for standard kernel (single LDS) +// Reduces code duplication in device layer implementations +template +__device__ void DispatchSplitKHack(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const typename GridwiseGemm::Argument& karg, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + index_t k_id, + index_t k_batch, + bool split_k_offset_hack) +{ + if(split_k_offset_hack) + { + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } + else + { + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } +} + +// Helper function to dispatch split-K hack for 2lds kernel +// Reduces code duplication in device layer implementations +template +__device__ void DispatchSplitKHack_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const typename GridwiseGemm::Argument& karg, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + index_t k_id, + index_t k_batch, + bool split_k_offset_hack) +{ + if(split_k_offset_hack) + { + GridwiseGemm::template Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } + else + { + GridwiseGemm::template Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } +} + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index 9339916d6f..8188c42ca5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -663,7 +663,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, - TailNumber TailNum = TailNumber::Odd> + TailNumber TailNum = TailNumber::Odd, + bool SplitKOffsetHack = false> __device__ static void Run(const ADataType* p_a_grid, const BDataType* p_b_grid, CDataType* p_c_grid, @@ -673,12 +674,16 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, const 
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, - const index_t k_id = 0) + const index_t k_id = 0, + const index_t k_batch = 1) { + const long_index_t a_space_size_divisor = SplitKOffsetHack ? k_batch : 1; + const long_index_t b_space_size_divisor = SplitKOffsetHack ? k_batch : 1; + const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() / a_space_size_divisor); const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize() / b_space_size_divisor); auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); @@ -744,7 +749,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 true, BlockwiseGemmPipe::GlobalBufferNum>( a_grid_desc_ak0_m_ak1, - make_multi_index(k_id, m_block_data_idx_on_grid, 0), + make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0), a_element_op, a_block_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -775,7 +780,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 true, BlockwiseGemmPipe::GlobalBufferNum>( b_grid_desc_bk0_n_bk1, - make_multi_index(k_id, n_block_data_idx_on_grid, 0), + make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), b_element_op, b_block_desc_bk0_n_bk1, make_multi_index(0, 0, 0), @@ -1024,7 +1029,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, - TailNumber TailNum = TailNumber::Odd> + TailNumber TailNum = TailNumber::Odd, + bool SplitKOffsetHack = false> __device__ static void Run_2Lds(const ADataType* p_a_grid, const BDataType* p_b_grid, CDataType* p_c_grid, @@ -1035,12 +1041,16 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, - const index_t k_id = 0) + const index_t k_id = 0, + const index_t k_batch = 1) { + const long_index_t a_space_size_divisor = SplitKOffsetHack ? k_batch : 1; + const long_index_t b_space_size_divisor = SplitKOffsetHack ? k_batch : 1; + const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() / a_space_size_divisor); const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize() / b_space_size_divisor); auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); @@ -1106,7 +1116,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 true, BlockwiseGemmPipe::GlobalBufferNum>( a_grid_desc_ak0_m_ak1, - make_multi_index(k_id, m_block_data_idx_on_grid, 0), + make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0), a_element_op, a_block_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -1137,7 +1147,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 true, BlockwiseGemmPipe::GlobalBufferNum>( b_grid_desc_bk0_n_bk1, - make_multi_index(k_id, n_block_data_idx_on_grid, 0), + make_multi_index(SplitKOffsetHack ? 
0 : k_id, n_block_data_idx_on_grid, 0), b_element_op, b_block_desc_bk0_n_bk1, make_multi_index(0, 0, 0), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 6fd6529fbb..e6f055d183 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -14,6 +14,7 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" namespace ck { @@ -149,7 +150,8 @@ template + bool HasMainKBlockLoop, + bool SplitKOffsetHack> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -164,7 +166,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, - const CBlockClusterAdaptor c_block_cluster_adaptor) + const CBlockClusterAdaptor c_block_cluster_adaptor, + const long_index_t split_k_stride_a, + const long_index_t split_k_stride_b, + index_t k_batch) { #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ defined(__gfx12__) @@ -172,17 +177,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_grid_desc_mblock_mperblock_nblock_nperblock, - a_element_op, - b_element_op, - c_element_op, - c_block_cluster_adaptor); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + c_block_cluster_adaptor, + split_k_stride_a, + split_k_stride_b, + k_batch); } #else ignore = p_a_grid; @@ -195,6 +204,9 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) ignore = b_element_op; ignore = c_element_op; ignore = c_block_cluster_adaptor; + ignore = split_k_stride_a; + ignore = split_k_stride_b; + ignore = k_batch; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -536,7 +548,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight MRepeat, NRepeat, FloatC, - CGlobalMemoryDataOperation>(); + CGlobalMemoryDataOperation_>(); } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template @@ -646,6 +658,416 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{})); using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); + template + __device__ static void Run(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const 
CElementwiseOperation& c_element_op, + const CBlockClusterAdaptor& c_block_cluster_adaptor, + const long_index_t split_k_stride_a, + const long_index_t split_k_stride_b, + index_t k_batch) + { + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + + // divide block work by [M, N] + const auto block_work_idx = + c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t k_batch_id = block_work_idx[I0]; + + // Use compile-time branching based on template parameters + const long_index_t split_k_offset_a = SplitKOffsetHack ? k_batch_id * split_k_stride_a : 0; + const long_index_t split_k_offset_b = SplitKOffsetHack ? k_batch_id * split_k_stride_b : 0; + + // When hack is enabled, buffer size equals the stride (calculated from descriptor's + // CalculateOffset method in the device layer). This properly accounts for the + // descriptor's transform pipeline and non-compact strides. + // When hack is disabled, use the full element space size. + const long_index_t a_buffer_size = + SplitKOffsetHack ? split_k_stride_a : a_b_k0_m_k1_grid_desc.GetElementSpaceSize(); + + const long_index_t b_buffer_size = + SplitKOffsetHack ? split_k_stride_b : b_b_k0_n_k1_grid_desc.GetElementSpaceSize(); + + ignore = k_batch; // k_batch value itself not used in this function + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid + split_k_offset_a, a_buffer_size); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid + split_k_offset_b, b_buffer_size); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + if(!c_block_cluster_adaptor.ValidCTileIndex( + make_tuple(block_work_idx[I1], block_work_idx[I2]), + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1(); + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatA, + FloatAAdjusted, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_b_k0_m_k1_grid_desc, + make_multi_index(SplitKOffsetHack ? 
0 : k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_b_k0_m_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatB, + FloatBAdjusted, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_b_k0_n_k1_grid_desc, + make_multi_index(SplitKOffsetHack ? 0 : k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + K1 <= 4) || + (is_same::value && K1 <= 8) || + ((is_same::value || is_same::value) && + K1 < 32)) + ? true + : false; + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(K1, + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size, + b_k0_n_k1_block_desc.GetElementSpaceSize()); + + // gridwise GEMM pipeline + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + + GridwiseGemmPipe::template Run(a_b_k0_m_k1_grid_desc, + a_b_k0_m_k1_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_b_k0_n_k1_grid_desc, + b_b_k0_n_k1_block_desc, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); + + // output: register to global memory + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = 
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + static_assert(M1 == MWave, ""); + static_assert(N1 == NWave, ""); + static_assert(M2 * M3 * M4 == MPerXdl, ""); + static_assert(N2 == NPerXdl, ""); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle, + M1, + M2, + M3, + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_freeze_transform(I0), // freeze nblock + make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, + N1, + N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerXdl, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatC, // typename SrcData, + FloatC, // typename DstData, + decltype(c_block_desc_mblock_mperblock_nblock_nperblock), + 
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun + {c_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl); + + static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? nxdlperwave_iter + : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock, + c_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step); + } + }); + } + } + template __device__ static void Run(const FloatA* __restrict__ p_a_grid, const FloatB* __restrict__ p_b_grid, diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp index 266ffb5fae..3379fb2c59 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp @@ -149,7 +149,8 @@ struct TransformConvBwdWeightToGemm const std::array& conv_filter_dilations, const std::array& input_left_pads, const std::array& input_right_pads, - const index_t batch_k) + const index_t batch_k, + const bool split_k_offset_hack = false) // Deprecated parameter for backward compatibility { using namespace ck; @@ -172,7 +173,8 @@ 
struct TransformConvBwdWeightToGemm const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const index_t KBatchDim = split_k_offset_hack ? 1 : GemmKBatch; + const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number; if constexpr(ConvBackwardWeightSpecialization == device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) @@ -190,7 +192,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -208,7 +210,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -246,7 +248,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -285,7 +287,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -323,7 +325,8 @@ struct TransformConvBwdWeightToGemm const std::array& conv_filter_dilations, const std::array& input_left_pads, const std::array& input_right_pads, - const index_t batch_k) + const index_t batch_k, + const bool split_k_offset_hack = false) { using namespace ck; @@ -359,7 +362,8 @@ struct TransformConvBwdWeightToGemm const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const index_t KBatchDim = split_k_offset_hack ? 
1 : GemmKBatch; + const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number; const auto out_grid_desc = make_out_grid_desc(N, Ho, Wo, K, output_strides); const auto in_grid_desc = make_in_grid_desc(N, Hi, Wi, C, input_strides); @@ -378,7 +382,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -393,7 +397,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -422,7 +426,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -463,7 +467,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -497,7 +501,8 @@ struct TransformConvBwdWeightToGemm const std::array& conv_filter_dilations, const std::array& input_left_pads, const std::array& input_right_pads, - const index_t batch_k) + const index_t batch_k, + const bool split_k_offset_hack = false) { using namespace ck; @@ -540,7 +545,8 @@ struct TransformConvBwdWeightToGemm const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const index_t KBatchDim = split_k_offset_hack ? 
1 : GemmKBatch; + const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number; const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); const auto in_grid_desc = make_in_grid_desc(N, Di, Hi, Wi, C, input_strides); @@ -559,7 +565,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -574,7 +580,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -603,7 +609,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -653,7 +659,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp index 96482b1412..94eae555e9 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp @@ -324,7 +324,9 @@ struct TransformConvBwdWeightToGemmV2 const std::array& conv_filter_dilations, const std::array& input_left_pads, const std::array& input_right_pads, - const index_t batch_k) + const index_t batch_k, + const bool split_k_offset_hack = false, + const bool use_full_batch_kindex = false) { using namespace ck; @@ -353,7 +355,10 @@ struct TransformConvBwdWeightToGemmV2 const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + // When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise + // kernel compatibility + const index_t KBatchDim = (split_k_offset_hack && !use_full_batch_kindex) ? 
1 : GemmKBatch; + const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number; const auto out_grid_desc = make_out_grid_desc(N, Wo, K, output_strides); const auto in_grid_desc = make_in_grid_desc(N, Wi, C, input_strides); @@ -373,7 +378,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -389,7 +394,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -419,7 +424,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -460,7 +465,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -495,7 +500,9 @@ struct TransformConvBwdWeightToGemmV2 const std::array& conv_filter_dilations, const std::array& input_left_pads, const std::array& input_right_pads, - const index_t batch_k) + const index_t batch_k, + const bool split_k_offset_hack = false, + const bool use_full_batch_kindex = false) { using namespace ck; @@ -531,7 +538,10 @@ struct TransformConvBwdWeightToGemmV2 const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + // When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise + // kernel compatibility + const index_t KBatchDim = (split_k_offset_hack && !use_full_batch_kindex) ? 
1 : GemmKBatch; + const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number; const auto out_grid_desc = make_out_grid_desc(N, Ho, Wo, K, output_strides); const auto in_grid_desc = make_in_grid_desc(N, Hi, Wi, C, input_strides); @@ -551,7 +561,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -567,7 +577,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -597,7 +607,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -647,7 +657,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -681,7 +691,9 @@ struct TransformConvBwdWeightToGemmV2 const std::array& conv_filter_dilations, const std::array& input_left_pads, const std::array& input_right_pads, - const index_t batch_k) + const index_t batch_k, + const bool split_k_offset_hack = false, + const bool use_full_batch_kindex = false) { using namespace ck; @@ -724,7 +736,10 @@ struct TransformConvBwdWeightToGemmV2 const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + // When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise + // kernel compatibility + const index_t KBatchDim = (split_k_offset_hack && !use_full_batch_kindex) ? 
1 : GemmKBatch; + const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number; const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); const auto in_grid_desc = make_in_grid_desc(N, Di, Hi, Wi, C, input_strides); @@ -744,7 +759,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -760,7 +775,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -790,7 +805,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -855,7 +870,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); diff --git a/include/ck/utility/generic_memory_space_atomic.hpp b/include/ck/utility/generic_memory_space_atomic.hpp index b76d957044..07388c4847 100644 --- a/include/ck/utility/generic_memory_space_atomic.hpp +++ b/include/ck/utility/generic_memory_space_atomic.hpp @@ -111,6 +111,101 @@ __device__ double2_t atomic_add(double2_t* p_dst, const double2_t& x) return vy.template AsType()[I0]; } +#if defined(__gfx11__) +template <> +__device__ float8_t atomic_add(float8_t* p_dst, const float8_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomicAdd(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = + atomicAdd(c_style_pointer_cast(p_dst) + 1, vx.template AsType()[I1]); + vy.template AsType()(I2) = + atomicAdd(c_style_pointer_cast(p_dst) + 2, vx.template AsType()[I2]); + vy.template AsType()(I3) = + atomicAdd(c_style_pointer_cast(p_dst) + 3, vx.template AsType()[I3]); + vy.template AsType()(I4) = + atomicAdd(c_style_pointer_cast(p_dst) + 4, vx.template AsType()[I4]); + vy.template AsType()(I5) = + atomicAdd(c_style_pointer_cast(p_dst) + 5, vx.template 
AsType()[I5]); + vy.template AsType()(I6) = + atomicAdd(c_style_pointer_cast(p_dst) + 6, vx.template AsType()[I6]); + vy.template AsType()(I7) = + atomicAdd(c_style_pointer_cast(p_dst) + 7, vx.template AsType()[I7]); + + return vy.template AsType()[I0]; +} + +template <> +__device__ half4_t atomic_add(half4_t* p_dst, const half4_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomic_add(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = atomic_add(c_style_pointer_cast(p_dst) + 1, + vx.template AsType()[I1]); + vy.template AsType()(I2) = atomic_add(c_style_pointer_cast(p_dst) + 2, + vx.template AsType()[I2]); + vy.template AsType()(I3) = atomic_add(c_style_pointer_cast(p_dst) + 3, + vx.template AsType()[I3]); + + return vy.template AsType()[I0]; +} + +template <> +__device__ half8_t atomic_add(half8_t* p_dst, const half8_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomic_add(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = atomic_add(c_style_pointer_cast(p_dst) + 1, + vx.template AsType()[I1]); + vy.template AsType()(I2) = atomic_add(c_style_pointer_cast(p_dst) + 2, + vx.template AsType()[I2]); + vy.template AsType()(I3) = atomic_add(c_style_pointer_cast(p_dst) + 3, + vx.template AsType()[I3]); + vy.template AsType()(I4) = atomic_add(c_style_pointer_cast(p_dst) + 4, + vx.template AsType()[I4]); + vy.template AsType()(I5) = atomic_add(c_style_pointer_cast(p_dst) + 5, + vx.template AsType()[I5]); + vy.template AsType()(I6) = atomic_add(c_style_pointer_cast(p_dst) + 6, + vx.template AsType()[I6]); + vy.template AsType()(I7) = atomic_add(c_style_pointer_cast(p_dst) + 7, + vx.template AsType()[I7]); + + return vy.template AsType()[I0]; +} +#endif // defined(__gfx11__) + // Caution: DO NOT REMOVE // intentionally have only declaration but no definition to cause compilation failure when trying to // instantiate this template. The purpose is to make the implementation of atomic_max explicit for
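// ---------------------------------------------------------------------------
// Reviewer note: two illustrative sketches, not part of the patch.
//
// 1) K-padding arithmetic behind the KBatchDim changes above. A minimal
//    host-side sketch under assumed names: when the split-K offset hack is
//    active, the descriptor keeps only a single K-batch slice
//    (KBatchDim == 1), because the per-split displacement is applied to the
//    base pointer rather than folded into the K-batch dimension. The helper
//    name padded_gemm_k is hypothetical.
#include <hip/hip_runtime.h>
#include <cstdint>

inline std::int64_t integer_divide_ceil(std::int64_t a, std::int64_t b)
{
    return (a + b - 1) / b;
}

inline std::int64_t padded_gemm_k(std::int64_t gemm_k_total,
                                  std::int64_t gemm_k1,
                                  std::int64_t k0_per_block,
                                  std::int64_t k_batch,
                                  bool split_k_offset_hack)
{
    // K0 per split, rounded up to a multiple of K0PerBlock (mirrors GemmK0).
    const std::int64_t gemm_k0 =
        integer_divide_ceil(gemm_k_total, gemm_k1 * k0_per_block * k_batch) * k0_per_block;
    // With the offset hack the descriptor spans one K-batch only.
    const std::int64_t k_batch_dim = split_k_offset_hack ? 1 : k_batch;
    return k_batch_dim * gemm_k0 * gemm_k1; // corresponds to GemmKPad
}

// 2) The gfx11 vector atomic_add specializations added above all follow the
//    same pattern: a wide packed atomic add is emulated by issuing one
//    hardware atomic per element and repacking the returned old values.
//    A minimal HIP sketch with float elements and a fixed width of 8 (both
//    illustrative assumptions; the real code goes through CK's
//    vector_type/AsType accessors and also covers half-precision types):
__device__ inline void atomic_add_f32x8_emulated(float* p_dst,
                                                 const float (&x)[8],
                                                 float (&old_vals)[8])
{
#pragma unroll
    for(int i = 0; i < 8; ++i)
    {
        // Each scalar atomicAdd returns the previous contents of that element,
        // so the packed return value matches the vector-atomic contract.
        old_vals[i] = atomicAdd(p_dst + i, x[i]);
    }
}
// ---------------------------------------------------------------------------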