From 3864685a52c1aafa8f0735be858cc03e2979dcbe Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Mon, 23 May 2022 12:10:22 -0500 Subject: [PATCH] fix build (#246) * fix build * Revert "fix build" This reverts commit d73102384bfbb609e487d6d0cd04a3c8c9c4ec9e. * post PR #235 merge fix * amend Co-authored-by: Anthony Chang [ROCm/composable_kernel commit: ba58a93f606447bf9c6cf8e616683b4862567917] --- ...rd_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp | 20 +++--- .../grid/gridwise_gemm_xdlops_bwd_weight.hpp | 68 +++++-------------- 2 files changed, 25 insertions(+), 63 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index 386356cc84..96a86b39db 100644 --- a/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -802,17 +802,16 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; c_grid_desc_m_n_ = descs[I2]; + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, b_grid_desc_kbatch_k0_n_k1_, c_grid_desc_m_n_, - M01_, - N01_)) + block_2_ctile_map_)) { c_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_); - - block_2_ctile_map_ = - GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); } } @@ -871,14 +870,14 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) + arg.block_2_ctile_map_)) { throw std::runtime_error( "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); } - const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); - const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch); + const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); @@ -1066,8 +1065,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); + arg.block_2_ctile_map_); } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 6ada231547..0d3f8ddefb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -4,6 +4,7 @@ #include "multi_index_transform_helper.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" +#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "blockwise_gemm_xdlops.hpp" #include "thread_group_tensor_slice_transfer_v4r1.hpp" #include "thread_group_tensor_slice_transfer_v6r1.hpp" @@ -495,12 +496,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, const CMNGridDesc& c_m_n_grid_desc, - index_t M01, - index_t N01) + const Block2CTileMap& block_2_ctile_map) { static_assert(is_known_at_compile_time>::value, "wrong! K1 need to be known at compile-time"); @@ -532,31 +533,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) return false; - // check M01, N01 - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - if(!(M0 % M01 == 0 && N0 % N01 == 0)) + if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc)) + { return false; + } // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } - __host__ __device__ static constexpr index_t - CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc, index_t KBatch) - { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - const index_t grid_size = (M / MPerBlock) * (N / NPerBlock) * KBatch; - - return grid_size; - } - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) { // const bool has_main_k0_block_loop = K0 > K0PerBlock; @@ -588,37 +573,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch) { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; - - const auto kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_pass_through_transform(KBatch), - make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); - - const auto c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(KBatch, M00, N00, M01, N01))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto c_blockid_to_kbatch_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor); - - return c_blockid_to_kbatch_m0_n0_block_cluster_adaptor; + return BlockToCTileMap_KSplit_M00_N00_M01_N01( + c_m_n_grid_desc, M01, N01, KBatch); } __host__ __device__ static constexpr auto @@ -667,6 +623,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight const index_t k_batch_id = block_work_idx[I0]; + if(!c_block_cluster_adaptor.ValidCTileIndex( + make_tuple(block_work_idx[I1], block_work_idx[I2]), + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + // HACK: this force m/n_block_data_idx_on_grid into SGPR const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);