diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp index 7bc3be1a95..bbf62d5fbe 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -63,11 +63,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) #if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using e_data_type = remove_cvref_t>; - if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) + if constexpr(EGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd) { #endif __shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte< diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index a3b8be8bf8..30c1b1d490 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -62,10 +62,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) #if defined(__gfx11__) // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using c_data_type = remove_cvref_t>; - if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) + if constexpr(CGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd) { #endif @@ -1028,6 +1025,17 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 { return false; } + + if(arg.k_batch_ > 1 && ck::is_gfx11_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported splitK on gfx11." << std::endl; + } + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + if constexpr(std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp index 7f1669cf13..843705692b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -63,28 +63,34 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const index_t num_k_per_block) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + if constexpr(CGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd) + { +#endif + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>(); + __shared__ char p_shared[LDS_size]; - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); - __shared__ char p_shared[LDS_size]; + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; - - GridwiseGemm::template Run(p_shared, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - compute_ptr_offset_of_batch, - num_k_per_block, - karg, - epilogue_args); + GridwiseGemm::template Run(p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + compute_ptr_offset_of_batch, + num_k_per_block, + karg, + epilogue_args); +#if defined(__gfx11__) + } +#endif #else ignore = karg; ignore = a_grid_desc_ak0_m_ak1; @@ -1179,6 +1185,16 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 return false; } + if(arg.k_batch_ > 1 && ck::is_gfx11_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported splitK on gfx11." << std::endl; + } + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + // Check this here, it allows to use other instances from factory even // if workspace is not allocated if(!arg.p_workspace_) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index 213b72050e..c070d8d9e9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -64,11 +64,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) #if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using e_data_type = remove_cvref_t>; - if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) + if constexpr(CGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd) { #endif constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< @@ -1089,18 +1085,14 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 return false; } - if constexpr(std::is_same_v || - std::is_same_v) + if(gemm_arg.KBatch > 1 && ck::is_gfx11_supported()) { - if(gemm_arg.KBatch > 1 && ck::is_gfx11_supported()) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Unsupported splitK on gfx11." << std::endl; - } - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - return false; + std::cout << "Unsupported splitK on gfx11." << std::endl; } + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; } if constexpr(std::is_same_v || std::is_same_v ||