diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp index d483d8adfd..c785642eae 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp @@ -118,43 +118,101 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) [[maybe_unused]] const auto num_k_per_block = gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_.GetLength(Number<0>{}) / KBatch; - if(gemm_kernel_args[group_id].HasMainKBlockLoop_) + if constexpr(GridwiseGemm::DirectLoadEnabled) { - GridwiseGemm::template Run( - karg.p_a_grid + a_batch_offset + a_n_offset, - karg.p_b_grid + b_batch_offset, - p_ds_grid_grp, - karg.p_c_grid + e_batch_offset + e_n_offset, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op, - gemm_kernel_args[group_id].block_2_ctile_map_, - gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, - gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, - gemm_kernel_args[group_id].ds_grid_desc_m_n_, - gemm_kernel_args[group_id].e_grid_desc_m_n_, - KBatch, - k_idx); - } else { - GridwiseGemm::template Run( - karg.p_a_grid + a_batch_offset + a_n_offset, - karg.p_b_grid + b_batch_offset, - p_ds_grid_grp, - karg.p_c_grid + e_batch_offset + e_n_offset, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op, - gemm_kernel_args[group_id].block_2_ctile_map_, - gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, - gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, - gemm_kernel_args[group_id].ds_grid_desc_m_n_, - gemm_kernel_args[group_id].e_grid_desc_m_n_, - KBatch, - k_idx); +#if defined(__gfx950__) + if(gemm_kernel_args[group_id].HasMainKBlockLoop_) + { + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset + a_n_offset, + karg.p_b_grid + b_batch_offset, + p_ds_grid_grp, + karg.p_c_grid + e_batch_offset + e_n_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op, + gemm_kernel_args[group_id].block_2_ctile_map_, + GridwiseGemm::template TransformGrid( + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_), + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].ds_grid_desc_m_n_, + gemm_kernel_args[group_id].e_grid_desc_m_n_, + KBatch, + k_idx); + } else { + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset + a_n_offset, + karg.p_b_grid + b_batch_offset, + p_ds_grid_grp, + karg.p_c_grid + e_batch_offset + e_n_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op, + gemm_kernel_args[group_id].block_2_ctile_map_, + GridwiseGemm::template TransformGrid( + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_), + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].ds_grid_desc_m_n_, + gemm_kernel_args[group_id].e_grid_desc_m_n_, + KBatch, + k_idx); + } +#endif + } + else + { + if(gemm_kernel_args[group_id].HasMainKBlockLoop_) + { + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset + a_n_offset, + karg.p_b_grid + b_batch_offset, + p_ds_grid_grp, + karg.p_c_grid + e_batch_offset + e_n_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op, + gemm_kernel_args[group_id].block_2_ctile_map_, + GridwiseGemm::template TransformGrid( + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_), + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].ds_grid_desc_m_n_, + gemm_kernel_args[group_id].e_grid_desc_m_n_, + KBatch, + k_idx); + } else { + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset + a_n_offset, + karg.p_b_grid + b_batch_offset, + p_ds_grid_grp, + karg.p_c_grid + e_batch_offset + e_n_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op, + gemm_kernel_args[group_id].block_2_ctile_map_, + GridwiseGemm::template TransformGrid( + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_), + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].ds_grid_desc_m_n_, + gemm_kernel_args[group_id].e_grid_desc_m_n_, + KBatch, + k_idx); + } } #else ignore = karg; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index 0c2418b84b..ccf85acf22 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -272,7 +272,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 } template - __host__ __device__ static auto TransformGrid(GridDesc_K0_MN_K1_T& desc) + __host__ __device__ static auto TransformGrid(const GridDesc_K0_MN_K1_T& desc) { if constexpr(!DirectLoad)