diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 1ab3a124fa..a2b6520be6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -33,6 +33,111 @@ namespace ck { namespace tensor_operation { namespace device { +// Dispatch helper function for split-K hack - handles 4-way dispatch based on runtime flags +template +__device__ void DispatchBatchedGemmSplitKHack(const FloatA* p_a_grid, + const FloatB* p_b_grid, + FloatC* p_c_grid, + void* p_shared, + const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map, + const long_index_t split_k_stride_a, + const long_index_t split_k_stride_b, + bool split_k_offset_a_hack, + bool split_k_offset_b_hack, + index_t k_batch) +{ + if(split_k_offset_a_hack && split_k_offset_b_hack) + { + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map, + split_k_stride_a, + split_k_stride_b, + k_batch); + } + else if(split_k_offset_a_hack) + { + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map, + split_k_stride_a, + split_k_stride_b, + k_batch); + } + else if(split_k_offset_b_hack) + { + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map, + split_k_stride_a, + split_k_stride_b, + k_batch); + } + else + { + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map, + split_k_stride_a, + split_k_stride_b, + k_batch); + } +} + template (p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_c_grid + c_batch_offset, - p_shared, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_grid_desc_mblock_mperblock_nblock_nperblock, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map, - split_k_stride_a, - split_k_stride_b, - split_k_offset_a_hack, - split_k_offset_b_hack, - k_batch); + DispatchBatchedGemmSplitKHack( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map, + split_k_stride_a, + split_k_stride_b, + split_k_offset_a_hack, + split_k_offset_b_hack, + k_batch); } #else ignore = p_a_grid; @@ -474,7 +591,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle remove_reference_t, remove_reference_t, ComputePtrOffsetOfStridedBatch<>, - false>, // Both true/false give the same occupancy. + false>, // HasMainKBlockLoop - both true/false give the same occupancy BlockSize, dynamic_smem_size)); return std::max(1, max_occupancy); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 39caf76b63..85e35b0442 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -14,6 +14,7 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" namespace ck { @@ -149,7 +150,9 @@ template + bool HasMainKBlockLoop, + bool SplitKOffsetAHack, + bool SplitKOffsetBHack> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -167,8 +170,6 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CBlockClusterAdaptor c_block_cluster_adaptor, const long_index_t split_k_stride_a, const long_index_t split_k_stride_b, - bool split_k_offset_a_hack, - bool split_k_offset_b_hack, index_t k_batch) { #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ @@ -177,22 +178,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_grid_desc_mblock_mperblock_nblock_nperblock, - a_element_op, - b_element_op, - c_element_op, - c_block_cluster_adaptor, - split_k_stride_a, - split_k_stride_b, - split_k_offset_a_hack, - split_k_offset_b_hack, - k_batch); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + c_block_cluster_adaptor, + split_k_stride_a, + split_k_stride_b, + k_batch); } #else ignore = p_a_grid; @@ -207,8 +207,6 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) ignore = c_block_cluster_adaptor; ignore = split_k_stride_a; ignore = split_k_stride_b; - ignore = split_k_offset_a_hack; - ignore = split_k_offset_b_hack; ignore = k_batch; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -551,7 +549,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight MRepeat, NRepeat, FloatC, - CGlobalMemoryDataOperation>(); + CGlobalMemoryDataOperation_>(); } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template @@ -661,7 +659,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{})); using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); - template + template __device__ static void Run(const FloatA* __restrict__ p_a_grid, const FloatB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, @@ -676,8 +676,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight const CBlockClusterAdaptor& c_block_cluster_adaptor, const long_index_t split_k_stride_a, const long_index_t split_k_stride_b, - bool split_k_offset_a_hack, - bool split_k_offset_b_hack, index_t k_batch) { const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); @@ -688,20 +686,19 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight const index_t k_batch_id = block_work_idx[I0]; - const long_index_t split_k_offset_a = - split_k_offset_a_hack ? k_batch_id * split_k_stride_a : 0; - const long_index_t split_k_offset_b = - split_k_offset_b_hack ? k_batch_id * split_k_stride_b : 0; + // Use compile-time branching based on template parameters + const long_index_t split_k_offset_a = SplitKOffsetAHack ? k_batch_id * split_k_stride_a : 0; + const long_index_t split_k_offset_b = SplitKOffsetBHack ? k_batch_id * split_k_stride_b : 0; // When hack is enabled, buffer size equals the stride (calculated from descriptor's // CalculateOffset method in the device layer). This properly accounts for the // descriptor's transform pipeline and non-compact strides. // When hack is disabled, use the full element space size. const long_index_t a_buffer_size = - split_k_offset_a_hack ? split_k_stride_a : a_b_k0_m_k1_grid_desc.GetElementSpaceSize(); + SplitKOffsetAHack ? split_k_stride_a : a_b_k0_m_k1_grid_desc.GetElementSpaceSize(); const long_index_t b_buffer_size = - split_k_offset_b_hack ? split_k_stride_b : b_b_k0_n_k1_grid_desc.GetElementSpaceSize(); + SplitKOffsetBHack ? split_k_stride_b : b_b_k0_n_k1_grid_desc.GetElementSpaceSize(); ignore = k_batch; // k_batch value itself not used in this function @@ -763,7 +760,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight true>( a_b_k0_m_k1_grid_desc, make_multi_index( - split_k_offset_a_hack ? 0 : k_batch_id, 0, m_block_data_idx_on_grid, 0), + SplitKOffsetAHack ? 0 : k_batch_id, 0, m_block_data_idx_on_grid, 0), a_element_op, a_b_k0_m_k1_block_desc, make_multi_index(0, 0, 0, 0), @@ -794,7 +791,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight true>( b_b_k0_n_k1_grid_desc, make_multi_index( - split_k_offset_b_hack ? 0 : k_batch_id, 0, n_block_data_idx_on_grid, 0), + SplitKOffsetBHack ? 0 : k_batch_id, 0, n_block_data_idx_on_grid, 0), b_element_op, b_b_k0_n_k1_block_desc, make_multi_index(0, 0, 0, 0),