diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp index 5798822bb3..7040f74ec0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp @@ -59,130 +59,133 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const index_t num_k_per_block) { -#if defined(__gfx9__) - // offset base pointer for each work-group - const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x); - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); - const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * num_k_per_block); - - const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t e_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); - - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; - - index_t left = 0; - index_t right = gemms_count; - index_t group_id = index_t((left + right) / 2); - while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ && - block_args_id < gemm_kernel_args[group_id].BlockEnd_)) && - left <= right) +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - if(block_args_id < gemm_kernel_args[group_id].BlockStart_) - { - right = group_id; - } - else - { - left = group_id; - } - group_id = index_t((left + right) / 2); - } + // offset base pointer for each work-group + const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * num_k_per_block); - if constexpr(GridwiseGemm::DirectLoadEnabled) - { + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; + + index_t left = 0; + index_t right = gemms_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ && + block_args_id < gemm_kernel_args[group_id].BlockEnd_)) && + left <= right) + { + if(block_args_id < gemm_kernel_args[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + if constexpr(GridwiseGemm::DirectLoadEnabled) + { #if defined(__gfx950__) - const auto a_grid_desc_ak0_m_ak1_transformed = - GridwiseGemm::template TransformGrid( - gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_); - if(gemm_kernel_args[group_id].HasMainKBlockLoop_) - { - GridwiseGemm::template Run( - karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, - p_shared, - karg, - a_grid_desc_ak0_m_ak1_transformed, - gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, - gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, - k_idx, - gridDim.z, - blockIdx.x - gemm_kernel_args[group_id].BlockStart_); - } - else - { - GridwiseGemm::template Run( - karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, - p_shared, - karg, - a_grid_desc_ak0_m_ak1_transformed, - gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, - gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, - k_idx, - gridDim.z, - blockIdx.x - gemm_kernel_args[group_id].BlockStart_); - } + const auto a_grid_desc_ak0_m_ak1_transformed = + GridwiseGemm::template TransformGrid( + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_); + if(gemm_kernel_args[group_id].HasMainKBlockLoop_) + { + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1_transformed, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + k_idx, + gridDim.z, + blockIdx.x - gemm_kernel_args[group_id].BlockStart_); + } + else + { + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1_transformed, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + k_idx, + gridDim.z, + blockIdx.x - gemm_kernel_args[group_id].BlockStart_); + } #endif - } - else - { - if(gemm_kernel_args[group_id].HasMainKBlockLoop_) - { - GridwiseGemm::template Run( - karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, - p_shared, - karg, - gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, - gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, - gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, - k_idx, - gridDim.z, - blockIdx.x - gemm_kernel_args[group_id].BlockStart_); } else { - GridwiseGemm::template Run( - karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, - p_shared, - karg, - gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, - gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, - gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, - k_idx, - gridDim.z, - blockIdx.x - gemm_kernel_args[group_id].BlockStart_); + if(gemm_kernel_args[group_id].HasMainKBlockLoop_) + { + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + k_idx, + gridDim.z, + blockIdx.x - gemm_kernel_args[group_id].BlockStart_); + } + else + { + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + k_idx, + gridDim.z, + blockIdx.x - gemm_kernel_args[group_id].BlockStart_); + } } } #else @@ -192,7 +195,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) ignore = compute_ptr_offset_of_batch; ignore = num_k_per_block; -#endif // End of if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#endif // end of if (defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)) } } // namespace @@ -968,17 +971,18 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 { if(arg.stride_overflow) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "Unsupported! stride_overflow is set but LargeTensors is not enabled!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } } - if(get_warp_size() != 64) - { - // TODO: Enable for warp size 32 - return false; - } // Reject if the current warp size has no valid XDL configuration - // Warp size 32 is temporary not supported but leave it for the future if(get_warp_size() == 64) { if constexpr(MXdlPerWave64 == 0) @@ -994,6 +998,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 } } + if(!ck::is_xdl_wmma_supported()) + { + return false; + } + // check device if constexpr(DirectLoad) { @@ -1009,8 +1021,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - std::cout << "SplitK tests are not supported!" << " In " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; + std::cout << "SplitK(" << arg.k_batch_ << ") tests are not supported!" << " In " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; } return false; @@ -1021,8 +1034,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - std::cout << "SplitK tests are not supported!" << " In " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; + std::cout << "SplitK(" << arg.k_batch_ << ") tests are not supported!" << " In " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; } return false; @@ -1638,6 +1652,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 // clang-format off str << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3" + << (get_warp_size() != 64 ? "_WmmaPorted" : "") << (DirectLoad ? "_DirectLoad" : "") << (LargeTensors ? "_Large_Tensor" : "") << "<" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index c784827828..22c9cdcfd5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -1440,7 +1440,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 if constexpr(!LargeTensors) { if(arg.stride_overflow) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "Unsupported! stride_overflow is set but LargeTensors is not enabled!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; + } } // check device @@ -2023,6 +2032,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 // clang-format off str << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"; + if(get_warp_size() != 64) { + str << "_WmmaPorted"; + } + if constexpr(DirectLoad) { str << "_DirectLoad"; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 8a7fd8bb30..42de974de8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -1482,7 +1482,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 if constexpr(!LargeTensors) { if(arg.stride_overflow) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "Unsupported! stride_overflow is set but LargeTensors is not enabled!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; + } } namespace ctc = tensor_layout::convolution; diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_large_cases.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_large_cases.cpp index 8eef152327..257d20d9fb 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_large_cases.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_large_cases.cpp @@ -21,11 +21,19 @@ class TestGroupedConvndBwdData : public ::testing::Test using InLayout = std::tuple_element_t<3, Tuple>; std::vector conv_params; - std::vector split_ks{1, 2}; + std::vector split_ks{1}; template void Run() { + if constexpr(std::is_same_v) + { + if(ck::is_wmma_supported()) + { + GTEST_SKIP() << "Skipping test: WMMA architecture doesn't support FP32"; + } + } + EXPECT_FALSE(conv_params.empty()); bool pass = true; for(auto split_k : split_ks) diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_large_cases.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_large_cases.cpp index 9c2a216f71..b7c78e80ce 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_large_cases.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_large_cases.cpp @@ -47,6 +47,14 @@ class TestGroupedConvndBwdWeight : public ::testing::Test template void Run() { + if constexpr(std::is_same_v) + { + if(ck::is_wmma_supported()) + { + GTEST_SKIP() << "Skipping test: WMMA architecture doesn't support FP32"; + } + } + EXPECT_FALSE(conv_params.empty()); bool pass = true; diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases.cpp index 6452f345fe..90b948246b 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases.cpp @@ -24,6 +24,14 @@ class TestGroupedConvndFwd : public ::testing::Test template void Run() { + if constexpr(std::is_same_v) + { + if(ck::is_wmma_supported()) + { + GTEST_SKIP() << "Skipping test: WMMA architecture doesn't support FP32"; + } + } + EXPECT_FALSE(conv_params.empty()); bool pass = true; for(auto& param : conv_params)