diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
index f6f354f98e..efb91bd13d 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
@@ -25,6 +25,8 @@
 #include "ck/host_utility/flush_cache.hpp"
 #include "ck/host_utility/io.hpp"
 
+#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
+
 namespace ck {
 namespace tensor_operation {
 namespace device {
@@ -51,6 +53,11 @@ namespace {
  * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
  * pointer offset into \p ComputePtrOffsetOfStridedBatch.
  *
+ * MaxGroupedGemmGroupsNum specifies the number of GEMM argument slots at compile time. With this
+ * implementation we avoid copying data to a workspace before the kernel launch, since the number
+ * of groups is a runtime parameter. If the number of groups exceeds MaxGroupedGemmGroupsNum, the
+ * kernel is launched in a loop.
+ *
 * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
 * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
 * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
@@ -60,17 +67,13 @@ template <typename GridwiseGemm,
           typename ABDataType,
           typename DsPointer,
           typename EDataType,
+          index_t MaxGroupedGemmGroupsNum,
+          typename GemmArgs,
           typename AElementwiseOp,
           typename BElementwiseOp,
           typename CDEElementwiseOp,
-          typename AGridDesc_AK0_M_AK1,
-          typename BGridDesc_BK0_N_BK1,
-          typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-          typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
-          typename Block2ETileMap,
           typename ComputePtrOffsetOfBatch,
           typename ComputePtrOffsetOfN,
-          bool HasMainKBlockLoop,
          typename ElementOp>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
@@ -81,25 +84,21 @@ __global__ void
         const ABDataType* __restrict__ p_a_grid,
         const ABDataType* __restrict__ p_b_grid,
         DsPointer p_ds_grid,
         EDataType* __restrict__ p_e_grid,
+        const std::array<GemmArgs, MaxGroupedGemmGroupsNum> gemm_kernel_args,
+        const index_t gemms_count,
         const AElementwiseOp a_element_op,
         const BElementwiseOp b_element_op,
         const CDEElementwiseOp cde_element_op,
-        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-        const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-            ds_grid_desc_mblock_mperblock_nblock_nperblock,
-        const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-            e_grid_desc_mblock_mperblock_nblock_nperblock_,
-        const Block2ETileMap block_2_ctile_map,
         const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
         const ComputePtrOffsetOfN compute_ptr_offset_of_n,
         const index_t KBatch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     // offset base pointer for each work-group
-    const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
-    const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch);
-    const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch);
+    const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
+    const index_t g_idx         = __builtin_amdgcn_readfirstlane(blockIdx.y);
+    const index_t n_idx         = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch);
+    const index_t k_idx         = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch);
 
     const long_index_t a_batch_offset =
         amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
@@ -119,43 +118,79 @@ __global__ void
     DsPointer p_ds_grid_grp;
 
-    static constexpr index_t NumDTensor =
-        DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
+    static constexpr index_t NumDTensor = DsPointer::Size();
 
     static_for<0, NumDTensor, 1>{}(
         [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
 
-    GridwiseGemm::template Run<HasMainKBlockLoop>(
-        p_a_grid + a_batch_offset + a_n_offset,
-        p_b_grid + b_batch_offset,
-        p_ds_grid_grp,
-        p_e_grid + e_batch_offset + e_n_offset,
-        p_shared,
-        a_element_op,
-        b_element_op,
-        cde_element_op,
-        a_grid_desc_ak0_m_ak1,
-        b_grid_desc_bk0_n_bk1,
-        ds_grid_desc_mblock_mperblock_nblock_nperblock,
-        e_grid_desc_mblock_mperblock_nblock_nperblock_,
-        block_2_ctile_map,
-        KBatch,
-        k_idx);
+    // Binary search for the GEMM group that owns this workgroup, over the
+    // contiguous [BlockStart_, BlockEnd_) ranges stored in gemm_kernel_args.
+    index_t left     = 0;
+    index_t right    = gemms_count;
+    index_t group_id = index_t((left + right) / 2);
+    while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ &&
+             block_args_id < gemm_kernel_args[group_id].BlockEnd_)) &&
+          left <= right)
+    {
+        if(block_args_id < gemm_kernel_args[group_id].BlockStart_)
+        {
+            right = group_id;
+        }
+        else
+        {
+            left = group_id;
+        }
+        group_id = index_t((left + right) / 2);
+    }
+
+    if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
+    {
+        GridwiseGemm::template Run<true>(
+            p_a_grid + a_batch_offset + a_n_offset,
+            p_b_grid + b_batch_offset,
+            p_ds_grid_grp,
+            p_e_grid + e_batch_offset + e_n_offset,
+            p_shared,
+            a_element_op,
+            b_element_op,
+            cde_element_op,
+            gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
+            gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
+            gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
+            gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
+            gemm_kernel_args[group_id].block_2_ctile_map_,
+            KBatch,
+            k_idx);
+    }
+    else
+    {
+        GridwiseGemm::template Run<false>(
+            p_a_grid + a_batch_offset + a_n_offset,
+            p_b_grid + b_batch_offset,
+            p_ds_grid_grp,
+            p_e_grid + e_batch_offset + e_n_offset,
+            p_shared,
+            a_element_op,
+            b_element_op,
+            cde_element_op,
+            gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
+            gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
+            gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
+            gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
+            gemm_kernel_args[group_id].block_2_ctile_map_,
+            KBatch,
+            k_idx);
+    }
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
     ignore = p_ds_grid;
     ignore = p_e_grid;
-    ignore = a_grid_desc_ak0_m_ak1;
-    ignore = b_grid_desc_bk0_n_bk1;
-    ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
-    ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
+    ignore = gemm_kernel_args;
+    ignore = gemms_count;
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = cde_element_op;
     ignore = compute_ptr_offset_of_batch;
     ignore = compute_ptr_offset_of_n;
-    ignore = block_2_ctile_map;
 #endif
 }
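
To make the new dispatch easier to review: each workgroup reads `blockIdx.x` and binary-searches the per-group `[BlockStart_, BlockEnd_)` ranges to find the GEMM it belongs to. Below is a minimal standalone sketch of that lookup; `GemmSpan` and `find_group` are illustrative names, not part of this PR, and it uses the textbook `left = mid + 1` update, whereas the in-kernel variant relies on the ranges being contiguous to terminate.

```cpp
#include <array>
#include <cassert>

// Illustrative stand-in for GemmArgs: only the block range matters here.
struct GemmSpan
{
    int BlockStart_; // first grid block owned by this GEMM (inclusive)
    int BlockEnd_;   // one past the last grid block owned by this GEMM
};

// Binary search over sorted, contiguous [BlockStart_, BlockEnd_) ranges;
// block_id is expected to fall inside one of the first gemms_count ranges.
template <std::size_t MaxGroups>
int find_group(const std::array<GemmSpan, MaxGroups>& spans, int gemms_count, int block_id)
{
    int left  = 0;
    int right = gemms_count; // exclusive
    while(left < right)
    {
        const int mid = (left + right) / 2;
        if(block_id < spans[mid].BlockStart_)
            right = mid;
        else if(block_id >= spans[mid].BlockEnd_)
            left = mid + 1;
        else
            return mid;
    }
    return -1; // unreachable for a well-formed grid
}

int main()
{
    // Three GEMMs covering grid blocks [0,4), [4,10), [10,12); trailing slots unused.
    std::array<GemmSpan, 32> spans{{{0, 4}, {4, 10}, {10, 12}}};
    assert(find_group(spans, 3, 0) == 0);
    assert(find_group(spans, 3, 7) == 1);
    assert(find_group(spans, 3, 11) == 2);
    return 0;
}
```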
@@ -239,6 +274,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
     static_assert(NDimSpatial == 2 || NDimSpatial == 3,
                   "wrong! only implemented for 2D and 3D now");
 
+    // MaxGroupedGemmGroupsNum specifies the number of GEMM argument slots at compile time. With
+    // this implementation we avoid copying data to a workspace before the kernel launch, since
+    // the number of groups is a runtime parameter. If the number of groups exceeds
+    // MaxGroupedGemmGroupsNum, the kernel is launched in a loop.
+    static constexpr index_t MaxGroupedGemmGroupsNum = 32;
+
     using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;
 
     static constexpr index_t NumDTensor = DsDataType::Size();
@@ -378,15 +419,58 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
     using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{}));
 
     using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
-        decltype(GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmMultiDTemplateParams>::
-                     MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}));
+        decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+            DsGridDesc_M_N{}));
     using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
         decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}));
 
     // block-to-e-tile map
-    using Block2ETileMap = remove_cvref_t<
-        decltype(GridwiseGemmMultipleD_xdl_cshuffle<
-                 GridwiseGemmMultiDTemplateParams>::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
+    using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
+
+    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;
+
+    struct GemmArgs
+    {
+        GemmArgs() = default;
+        GemmArgs(AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+                 BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+                 DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+                     ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                 EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+                     e_grid_desc_mblock_mperblock_nblock_nperblock,
+                 GroupedGemmBlock2ETileMap block_2_ctile_map,
+                 index_t BlockStart,
+                 index_t BlockEnd,
+                 bool HasMainKBlockLoop)
+            : a_grid_desc_ak0_m_ak1_(a_grid_desc_ak0_m_ak1),
+              b_grid_desc_bk0_n_bk1_(b_grid_desc_bk0_n_bk1),
+              ds_grid_desc_mblock_mperblock_nblock_nperblock_(
+                  ds_grid_desc_mblock_mperblock_nblock_nperblock),
+              e_grid_desc_mblock_mperblock_nblock_nperblock_(
+                  e_grid_desc_mblock_mperblock_nblock_nperblock),
+              block_2_ctile_map_(block_2_ctile_map),
+              BlockStart_(BlockStart),
+              BlockEnd_(BlockEnd),
+              HasMainKBlockLoop_(HasMainKBlockLoop)
+        {
+        }
+
+        // tensor descriptors for block/thread-wise copy
+        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
+        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
+        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+            ds_grid_desc_mblock_mperblock_nblock_nperblock_;
+        EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
+
+        // block-to-e-tile map
+        GroupedGemmBlock2ETileMap block_2_ctile_map_;
+        index_t BlockStart_, BlockEnd_;
+        bool HasMainKBlockLoop_;
+    };
 
     using Block2TileMapInOutElementwise = BlockToCTileMap_M00_N0_M01Adapt;
     using Block2TileMapWeiElementwise   = BlockToCTileMap_M00_N0_M01Adapt;
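
One detail worth calling out: the kernel receives the whole `std::array<GemmArgs, MaxGroupedGemmGroupsNum>` by value, and the last launch may leave trailing slots unused, so `GemmArgs` and the tile map it contains must be default-constructible (which is why the `= default` constructors are added in block_to_ctile_map.hpp further down). A minimal sketch of that requirement, with hypothetical `TileMap`/`Args` types standing in for the real ones:

```cpp
#include <array>

struct TileMap
{
    TileMap() = default; // without this, Args{} and std::array<Args, 32>{} fail to compile
    TileMap(int m, int n) : m_(m), n_(n) {}
    int m_ = 0;
    int n_ = 0;
};

struct Args
{
    Args() = default;
    Args(TileMap map, int begin, int end) : map_(map), begin_(begin), end_(end) {}
    TileMap map_;
    int begin_ = 0;
    int end_   = 0;
};

int main()
{
    std::array<Args, 32> args{}; // unused trailing slots are default-constructed
    args[0] = Args{TileMap{4, 4}, 0, 16};
    return args[0].end_;
}
```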
@@ -589,9 +673,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
         const auto YTilde = ConvStrideH / GcdStrideDilationH;
         const auto XTilde = ConvStrideW / GcdStrideDilationW;
 
+        index_t grid_size = 0;
+        // Allocate one set of argument slots per MaxGroupedGemmGroupsNum gemms.
+        gemm_kernel_args_.resize(
+            math::integer_divide_ceil(ZTilde * YTilde * XTilde, MaxGroupedGemmGroupsNum));
+
         for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
         {
-
             for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
             {
                 for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
@@ -694,36 +782,51 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
                     ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
                     e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n);
 
-                    // desc for blockwise copy
-                    a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1);
-                    b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1);
+                    const index_t grid_size_grp = Block2ETileMap::CalculateGridSize(
+                        e_grid_desc_m_n.GetLength(I0), e_grid_desc_m_n.GetLength(I1));
 
-                    // block-to-e-tile-map
-                    auto block_2_etile_map =
-                        GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
+                    const index_t BlockStart = grid_size;
+                    const index_t BlockEnd   = grid_size + grid_size_grp;
 
-                    block_2_etile_map_container_.push_back(block_2_etile_map);
+                    grid_size += grid_size_grp;
 
-                    if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
-                                                   b_grid_desc_n_k,
-                                                   ds_grid_desc_m_n,
-                                                   e_grid_desc_m_n,
-                                                   block_2_etile_map,
-                                                   k_batch_))
+                    // block-to-e-tile map
+                    const auto block_2_etile_map =
+                        GroupedGemmBlock2ETileMap(Block2ETileMap(e_grid_desc_m_n.GetLength(I0),
+                                                                 e_grid_desc_m_n.GetLength(I1)),
+                                                  BlockStart);
+
+                    const auto GemmK = a_grid_desc_m_k.GetLength(I1);
+                    const bool HasMainKBlockLoop =
+                        GridwiseGemm::CalculateHasMainKBlockLoop(GemmK, k_batch_);
+
+                    gemm_kernel_args_[gemms_count_ / MaxGroupedGemmGroupsNum]
+                                     [gemms_count_ % MaxGroupedGemmGroupsNum] =
+                        GemmArgs{a_grid_desc_ak0_m_ak1,
+                                 b_grid_desc_bk0_n_bk1,
+                                 GridwiseGemm::
+                                     MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                                         ds_grid_desc_m_n),
+                                 MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                                     e_grid_desc_m_n),
+                                 block_2_etile_map,
+                                 BlockStart,
+                                 BlockEnd,
+                                 HasMainKBlockLoop};
+                    gemms_count_++;
+                    // Close the current set once it is full and start a fresh grid.
+                    if(gemms_count_ % MaxGroupedGemmGroupsNum == 0)
                     {
-                        ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
-                            GridwiseGemm::
-                                MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                                    ds_grid_desc_m_n));
-
-                        e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
-                            MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                                e_grid_desc_m_n));
+                        gemms_grid_size_.push_back(grid_size);
+                        grid_size = 0;
                     }
                 }
             }
         }
 
+        // Trim to the number of sets actually filled and record the last (partial) grid.
+        gemm_kernel_args_.resize(math::integer_divide_ceil(gemms_count_, MaxGroupedGemmGroupsNum));
+        gemms_grid_size_.push_back(grid_size);
+
         // A/B/Ds/E Batch Stride
         compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
         compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides_transposed[0];
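
The bookkeeping above partitions the per-tilde GEMMs into sets of at most MaxGroupedGemmGroupsNum, each set getting its own grid and kernel launch. A small host-side sketch of the arithmetic, which also mirrors the `gemms_count_for_set` computation in the launch loop further down (illustrative names; `ceil_div` stands in for `math::integer_divide_ceil`):

```cpp
#include <cassert>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main()
{
    constexpr int max_groups = 32;
    const int gemms_count    = 36; // e.g. ZTilde * YTilde * XTilde for stride {6, 6}

    // GEMM i lands in set i / max_groups, slot i % max_groups.
    assert(35 / max_groups == 1 && 35 % max_groups == 3);

    // Number of kernel launches and size of the trailing, partially filled set.
    const int num_sets = ceil_div(gemms_count, max_groups);
    const int tail     = gemms_count - max_groups * (num_sets - 1);
    assert(num_sets == 2 && tail == 4);
    return 0;
}
```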
@@ -830,31 +933,28 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
         void Print() const
         {
-            for(std::size_t i = 0; i < a_grid_desc_ak0_m_ak1_container_.size(); i++)
+            for(std::size_t i = 0; i < a_grid_desc_m_k_container_.size(); i++)
             {
-                std::cout << "a_grid_desc_ak0_m_ak1_container_"
-                          << a_grid_desc_ak0_m_ak1_container_[i] << std::endl;
+                std::cout << "a_grid_desc_m_k_container_" << a_grid_desc_m_k_container_[i]
+                          << std::endl;
 
-                std::cout << "b_grid_desc_bk0_n_bk1_container_"
-                          << b_grid_desc_bk0_n_bk1_container_[i] << std::endl;
+                std::cout << "b_grid_desc_n_k_container_" << b_grid_desc_n_k_container_[i]
+                          << std::endl;
 
                 static_for<0, NumDTensor, 1>{}([&](auto j) {
-                    std::cout << "ds_grid_desc_mblock_mperblock_nblock_nperblock_container_"
-                              << ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i][j]
-                              << std::endl;
+                    std::cout << "ds_grid_desc_m_n_container_"
+                              << ds_grid_desc_m_n_container_[i][j] << std::endl;
                 });
 
-                std::cout << "e_grid_desc_mblock_mperblock_nblock_nperblock_container_"
-                          << e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i]
-                          << std::endl;
+                std::cout << "e_grid_desc_m_n_container_" << e_grid_desc_m_n_container_[i]
+                          << std::endl;
             }
         }
 
         // pointers
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
-        typename GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmMultiDTemplateParams>::DsGridPointer
-            p_ds_grid_;
+        typename GridwiseGemm::DsGridPointer p_ds_grid_;
         EDataType* p_e_grid_;
 
         // tensor descriptor for problem definition
@@ -865,16 +965,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
         std::vector<DsGridDesc_M_N> ds_grid_desc_m_n_container_;
         std::vector<EGridDesc_M_N> e_grid_desc_m_n_container_;
 
-        // tensor descriptor for block-wise copy
-        std::vector<AGridDesc_AK0_M_AK1> a_grid_desc_ak0_m_ak1_container_;
-        std::vector<BGridDesc_BK0_N_BK1> b_grid_desc_bk0_n_bk1_container_;
-        std::vector<DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
-            ds_grid_desc_mblock_mperblock_nblock_nperblock_container_;
-        std::vector<EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
-            e_grid_desc_mblock_mperblock_nblock_nperblock_container_;
-
         // block-to-e-tile map
-        std::vector<Block2ETileMap> block_2_etile_map_container_;
         Block2TileMapInOutElementwise elementwise_block_2_ctile_map_transpose_a_,
             elementwise_block_2_ctile_map_transpose_e_;
         Block2TileMapWeiElementwise elementwise_block_2_ctile_map_transpose_b_;
@@ -903,6 +994,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
         const index_t k_batch_;
         index_t num_workgroups_per_Conv_N_;
 
+        std::vector<index_t> gemms_grid_size_;
+        index_t gemms_count_ = 0;
+        std::vector<std::array<GemmArgs, MaxGroupedGemmGroupsNum>> gemm_kernel_args_;
+
         bool bwd_needs_zero_out;
         long_index_t e_space_size_bytes;
     };
@@ -941,84 +1036,61 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
                     arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
             }
 
-            for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
+            for(std::size_t gemm_set_id = 0; gemm_set_id < arg.gemm_kernel_args_.size();
+                gemm_set_id++)
             {
-                if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
-                                                arg.b_grid_desc_n_k_container_[i],
-                                                arg.ds_grid_desc_m_n_container_[i],
-                                                arg.e_grid_desc_m_n_container_[i],
-                                                arg.block_2_etile_map_container_[i],
-                                                arg.k_batch_))
-                {
-                    throw std::runtime_error("wrong! device_op has invalid setting");
-                }
-
-                const index_t gdx = arg.block_2_etile_map_container_[i].CalculateGridSize(
-                    arg.e_grid_desc_m_n_container_[i]);
-
-                const auto GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1);
+                const index_t gdx = arg.gemms_grid_size_[gemm_set_id];
+                // The last set may be partially filled.
+                const index_t gemms_count_for_set =
+                    gemm_set_id == arg.gemm_kernel_args_.size() - 1
+                        ? arg.gemms_count_ - MaxGroupedGemmGroupsNum * gemm_set_id
+                        : MaxGroupedGemmGroupsNum;
+                const std::array<GemmArgs, MaxGroupedGemmGroupsNum>& gemm_kernel_args =
+                    arg.gemm_kernel_args_[gemm_set_id];
 
                 const auto clear_workspace = [&]() {
-                    if(arg.bwd_needs_zero_out && i == 0)
+                    if(arg.bwd_needs_zero_out && gemm_set_id == 0)
                     {
                         hip_check_error(hipMemsetAsync(
                             p_e_grid, 0, arg.e_space_size_bytes, stream_config.stream_id_));
                     }
                 };
 
-                auto launch_kernel = [&](auto has_main_k_block_loop) {
-                    constexpr bool has_main_loop = has_main_k_block_loop.value;
-
+                auto launch_kernel = [&]() {
                     const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
                         GridwiseGemm,
                         ADataType, // TODO: distiguish A/B datatype
                         typename GridwiseGemm::DsGridPointer,
                         EDataType,
+                        MaxGroupedGemmGroupsNum,
+                        GemmArgs,
                         AElementwiseOp,
                         BElementwiseOp,
                         CDEElementwiseOp,
-                        DeviceOp::AGridDesc_AK0_M_AK1,
-                        DeviceOp::BGridDesc_BK0_N_BK1,
-                        DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
-                        DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
-                        Block2ETileMap,
                         ComputePtrOffsetOfStridedBatch,
                         ComputePtrOffsetOfStridedBatch,
-                        has_main_loop,
                         ElementOp>;
 
-                    return launch_and_time_kernel_with_preprocess(
-                        stream_config,
-                        clear_workspace,
-                        kernel,
-                        dim3(gdx, gdy, gdz),
-                        dim3(BlockSize),
-                        0,
-                        p_a_grid,
-                        p_b_grid,
-                        arg.p_ds_grid_,
-                        p_e_grid,
-                        arg.a_element_op_,
-                        arg.b_element_op_,
-                        arg.cde_element_op_,
-                        arg.a_grid_desc_ak0_m_ak1_container_[i],
-                        arg.b_grid_desc_bk0_n_bk1_container_[i],
-                        arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
-                        arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
-                        arg.block_2_etile_map_container_[i],
-                        arg.compute_ptr_offset_of_batch_,
-                        arg.compute_ptr_offset_of_n_,
-                        arg.k_batch_);
+                    return launch_and_time_kernel_with_preprocess(stream_config,
+                                                                  clear_workspace,
+                                                                  kernel,
+                                                                  dim3(gdx, gdy, gdz),
+                                                                  dim3(BlockSize),
+                                                                  0,
+                                                                  p_a_grid,
+                                                                  p_b_grid,
+                                                                  arg.p_ds_grid_,
+                                                                  p_e_grid,
+                                                                  gemm_kernel_args,
+                                                                  gemms_count_for_set,
+                                                                  arg.a_element_op_,
+                                                                  arg.b_element_op_,
+                                                                  arg.cde_element_op_,
+                                                                  arg.compute_ptr_offset_of_batch_,
+                                                                  arg.compute_ptr_offset_of_n_,
+                                                                  arg.k_batch_);
                 };
 
-                if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK, arg.k_batch_))
-                {
-                    ave_time += launch_kernel(integral_constant<bool, true>{});
-                }
-                else
-                {
-                    ave_time += launch_kernel(integral_constant<bool, false>{});
-                }
+                ave_time += launch_kernel();
             }
 
             return ave_time;
@@ -1304,14 +1376,16 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
         }
 
         // Gridwise GEMM size
-        for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
+        for(std::size_t i = 0; i < arg.a_grid_desc_m_k_container_.size(); i++)
         {
-            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
-                                            arg.b_grid_desc_n_k_container_[i],
-                                            arg.ds_grid_desc_m_n_container_[i],
-                                            arg.e_grid_desc_m_n_container_[i],
-                                            arg.block_2_etile_map_container_[i],
-                                            arg.k_batch_))
+            if(!GridwiseGemm::CheckValidity(
+                   arg.a_grid_desc_m_k_container_[i],
+                   arg.b_grid_desc_n_k_container_[i],
+                   arg.ds_grid_desc_m_n_container_[i],
+                   arg.e_grid_desc_m_n_container_[i],
+                   arg.gemm_kernel_args_[i / MaxGroupedGemmGroupsNum][i % MaxGroupedGemmGroupsNum]
+                       .block_2_ctile_map_,
+                   arg.k_batch_))
             {
                 return false;
             }
diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
index dcc07d8a49..7eca68bbf8 100644
--- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
+++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
@@ -271,6 +271,7 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
 
+    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt() = default;
     __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M,
                                                                 index_t N,
                                                                 index_t M01 = 8)
@@ -870,6 +871,7 @@ struct OffsettedBlockToCTileMap
 {
     using underlying_type = UnderlyingBlockToCTileMap;
 
+    __host__ __device__ OffsettedBlockToCTileMap() = default;
     __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
                                                  index_t block_start)
     {
diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp
index 6cd8440e58..12f6ad606f 100644
--- a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp
+++ b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp
@@ -186,8 +186,8 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
             rtol = std::max(rtol, rtol_split_k);
             atol = std::max(atol, atol_split_k);
 
-            pass = pass & ck::utils::check_err(
-                              in_device, in_host, "Error: Incorrect results!", rtol, atol);
+            pass &= ck::utils::check_err(
+                in_device, in_host, "Error: Incorrect results!", rtol, atol);
 
             std::cout << "Relative error threshold: " << rtol
                       << " Absolute error threshold: " << atol << std::endl;
diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
index ca9b2f1d24..c1bb90dd9c 100644
--- a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
+++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
@@ -261,8 +261,9 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
                 ck::utils::get_absolute_threshold(
                     max_accumulated_value, num_accums_split_k);
             // Use higher threshold
-            rtol = std::max(rtol, rtol_split_k);
-            atol = std::max(atol, atol_split_k);
+            rtol = std::max(rtol, rtol_split_k);
+            atol = std::max(atol, atol_split_k);
+            // Use default atol for splitK == 1
 
             bool pass = ck::utils::check_err(weight_device_result,
                                              weight_host_result,
                                              "Error: Incorrect results!",
diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp
index 7f8f64c2e2..209b9b4f55 100644
--- a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp
+++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp
@@ -96,6 +96,18 @@ TYPED_TEST(TestGroupedConvndBwdDataXdl2d, Test2D)
 {
     this->conv_params.clear();
 
+    // GroupedGemmGroupsNum = 4, ZTilde * YTilde * XTilde = 4, MaxGroupedGemmGroupsNum = 32
+    this->conv_params.push_back(
+        {2, 2, 2, 16, 16, {3, 3}, {28, 28}, {2, 2}, {1, 1}, {1, 1}, {1, 1}});
+    // GroupedGemmGroupsNum = 9, ZTilde * YTilde * XTilde = 36, MaxGroupedGemmGroupsNum = 32
+    this->conv_params.push_back(
+        {2, 2, 2, 16, 16, {3, 3}, {28, 28}, {6, 6}, {1, 1}, {1, 1}, {1, 1}});
+    // GroupedGemmGroupsNum = 36, ZTilde * YTilde * XTilde = 36, MaxGroupedGemmGroupsNum = 32
+    this->conv_params.push_back(
+        {2, 2, 2, 16, 16, {6, 6}, {28, 28}, {6, 6}, {1, 1}, {1, 1}, {1, 1}});
+    // GroupedGemmGroupsNum = 32, ZTilde * YTilde * XTilde = 32, MaxGroupedGemmGroupsNum = 32
+    this->conv_params.push_back(
+        {2, 2, 2, 16, 16, {4, 8}, {28, 28}, {4, 8}, {1, 1}, {1, 1}, {1, 1}});
     this->conv_params.push_back(
         {2, 2, 2, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
     this->conv_params.push_back(
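
The new test cases exercise the set partitioning, and the `GroupedGemmGroupsNum` figures in their comments can be sanity-checked by hand. Using the device op's own formula `YTilde = ConvStrideH / gcd(ConvStrideH, ConvDilationH)`, and my reading that with dilation 1 a tilde index yields a non-empty GEMM only when a filter tap falls on it (so `min(filter, tilde)` active indices per axis), a quick check with hypothetical helper names:

```cpp
#include <algorithm> // std::min
#include <cassert>
#include <numeric>   // std::gcd

// Per-axis tilde count, as in the device op.
int tilde(int stride, int dilation) { return stride / std::gcd(stride, dilation); }

// Assumes dilation == 1 on both axes, as in all four new test cases.
int active_gemms_2d(int fy, int fx, int sy, int sx)
{
    const int ytilde = tilde(sy, 1);
    const int xtilde = tilde(sx, 1);
    return std::min(fy, ytilde) * std::min(fx, xtilde);
}

int main()
{
    assert(active_gemms_2d(3, 3, 2, 2) == 4);  // single set, far below the cap
    assert(active_gemms_2d(3, 3, 6, 6) == 9);  // 9 of the 36 tildes are non-empty
    assert(active_gemms_2d(6, 6, 6, 6) == 36); // exceeds MaxGroupedGemmGroupsNum = 32
    assert(active_gemms_2d(4, 8, 4, 8) == 32); // exactly MaxGroupedGemmGroupsNum
    return 0;
}
```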