From 7a98b5d00247b61095d32ce6852bb2851b584a0a Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Fri, 20 Feb 2026 23:56:29 +0800 Subject: [PATCH] [ck] Support VGPR estimate in GridwiseGemm_wmma_cshuffle_v3 (#4638) 1. Add GetEstimateVgprCount to estimate the VGPR usage in GridwiseGemm_wmma_cshuffle_v3 2. Add IsValidCompilationParameter to disable kernels which use too many VGPRs. - Currently, the threshold is AvailableVgprCount * 1.25 3. Modify examples to avoid tests being disabled on gfx11 It is ported from internal repo PR[#192](https://github.com/ROCm/composable_kernel/issues/192) ## Motivation ## Technical Details ## Test Plan ## Test Result ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: illsilin_amdeng --- .../gemm_wmma_splitk_reduce_multi_d_fp16.cpp | 4 +- ...ultiply_multiply_wmma_fp16_bpreshuffle.cpp | 6 +- ...v_bwd_data_multiple_d_wmma_cshuffle_v3.hpp | 9 +- .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 820 ++++++++++-------- 4 files changed, 465 insertions(+), 374 deletions(-) diff --git a/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_fp16.cpp b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_fp16.cpp index ae5bf950a7..6cb0fb2106 100644 --- a/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_fp16.cpp +++ b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_fp16.cpp @@ -33,10 +33,10 @@ using DeviceGemmV2Instance = ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 256, - 128, 256, 64, + 128, 128, 64, 8, 8, 16, 16, - 4, 4, + 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp16_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp16_bpreshuffle.cpp index 5e0851dbb0..90cc3f7985 100644 --- 
a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp16_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp16_bpreshuffle.cpp @@ -65,12 +65,12 @@ using DeviceOpInstance = A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, - 32, 128, 128, + 32, 128, 64, 8, 8, 16, 16, 2, 2, - S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, - S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp index dfdfd53725..854e78851f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -1785,12 +1785,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 p_ds_grid_dummy[i] = nullptr; StrideDs_dummy[i] = I0; }); - for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) + for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); i++) { - const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I1); - const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_container_[i].GetLength(I1); - const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) * - arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2); + const index_t GemmM = arg.a_grid_desc_m_k_container_[i].GetLength(I0); + const index_t GemmN = arg.b_grid_desc_n_k_container_[i].GetLength(I0); + const index_t GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1); // 
Create gemm arguments with dummy values to check for validity typename GridwiseGemmCTranspose::Argument gemm_arg{ std::array{nullptr}, // p_as_grid diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index a1cba118b2..8d9935833d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -5,6 +5,7 @@ #include "ck/utility/env.hpp" #include "ck/utility/common_header.hpp" +#include "ck/host_utility/device_prop.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" @@ -568,6 +569,82 @@ struct GridwiseGemm_wmma_cshuffle_v3 using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + template + static constexpr index_t GetEstimateVgprCount() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); + + // VGPR used in LDS loading and WMMA + constexpr index_t BaseInputVgprCount = + MPerBlock * KPerBlock / MWave / WaveSize * sizeof(ComputeTypeA) / sizeof(uint32_t) + + NPerBlock * KPerBlock / NWave / WaveSize * sizeof(ComputeTypeB) / sizeof(uint32_t); + // WMMA input is duplicated in GFX11 + constexpr index_t InputVgprCount = IsGfx11 ? 
BaseInputVgprCount * 2 : BaseInputVgprCount; + // VGPR used in buffer load and LDS store + constexpr index_t TempVgprCount = BaseInputVgprCount / 2; + // VGPR used in Accumulator + constexpr index_t AccVgprCount = + MPerBlock * NPerBlock / BlockSize * sizeof(AccDataType) / sizeof(uint32_t); + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + return InputVgprCount + TempVgprCount + AccVgprCount; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return InputVgprCount * 2 + TempVgprCount + AccVgprCount; + } + else + { + static_assert(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v1, + "Invalid pipeline version"); + } + } + + __device__ static bool constexpr IsValidCompilationParameter() + { + constexpr bool IsGfx11 = is_same_v; + constexpr auto EstimateVgprCount = GetEstimateVgprCount(); + constexpr auto AvailableVgprCount = get_max_vgpr_count(get_device_arch()); + if constexpr(EstimateVgprCount > (AvailableVgprCount + AvailableVgprCount / 4)) + { + return false; + } + else + { + return true; + } + } + + template + __host__ static bool CheckValidity(const Argument& karg, bool allow_short_v3_pipe = false) + { + const auto availableVgprCount = []() { + if(ck::is_gfx12_supported()) + { + return get_max_vgpr_count(gfx12_t{}); + } + else if(ck::is_gfx11_supported()) + { + return get_max_vgpr_count(gfx11_t{}); + } + else + { + return get_max_vgpr_count(gfx_invalid_t{}); + } + }(); + const auto estimateVgprCount = + ck::is_gfx11_supported() ? 
GetEstimateVgprCount() : GetEstimateVgprCount(); + if(estimateVgprCount > (availableVgprCount + availableVgprCount / 4)) + { + return false; + } + + return Base::template CheckValidity(karg, allow_short_v3_pipe); + } __device__ static index_t GetKBlockPerScale() { return 1; } template {}]); + const index_t block_n_id = + __builtin_amdgcn_readfirstlane(block_work_idx[Number{}]); + + // BScale struct (Empty) + using Scale = typename BlockwiseGemmPipe::Empty; + auto a_scale_struct = Scale{}; + auto b_scale_struct = Scale{}; + + const index_t num_k_block_per_scale = GetKBlockPerScale(); + + Base::template Run(p_as_grid, + p_bs_grid, + p_ds_grid, + p_e_grid, + p_shared, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + cde_element_op, + block_m_id, + block_n_id, + num_k_block_per_scale, + a_scale_struct, + b_scale_struct, + epilogue_args, + A_k_id, + B_k_id); } - - const index_t block_m_id = - __builtin_amdgcn_readfirstlane(block_work_idx[Number{}]); - const index_t block_n_id = - __builtin_amdgcn_readfirstlane(block_work_idx[Number{}]); - - // BScale struct (Empty) - using Scale = typename BlockwiseGemmPipe::Empty; - auto a_scale_struct = Scale{}; - auto b_scale_struct = Scale{}; - - const index_t num_k_block_per_scale = GetKBlockPerScale(); - - Base::template Run(p_as_grid, - p_bs_grid, - p_ds_grid, - p_e_grid, - p_shared, - as_grid_desc_ak0_m_ak1, - bs_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - a_element_op, - b_element_op, - cde_element_op, - block_m_id, - block_n_id, - num_k_block_per_scale, - a_scale_struct, - b_scale_struct, - epilogue_args, - A_k_id, - B_k_id); } template {}([&](auto i) { - using ADataType_ = remove_cvref_t>; - p_as_grid_(i) = - static_cast(karg.p_as_grid[i]) + a_batch_offset + a_n_offset; - }); - - BsGridPointer p_bs_grid_; - 
static_for<0, NumBTensor, 1>{}([&](auto i) { - using BDataType_ = remove_cvref_t>; - p_bs_grid_(i) = - static_cast(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset; - }); - - DsGridPointer p_ds_grid_grp; - static_for<0, NumDTensor, 1>{}( - [&](auto i) { p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_batch_offset[i]; }); - - // Currently supporting one A and one B - const auto as_grid_desc_ak0_m_ak1 = generate_tuple( - [&](auto i) { - ignore = i; - return a_grid_desc_ak0_m_ak1; - }, - Number{}); - - const auto bs_grid_desc_bk0_n_bk1 = generate_tuple( - [&](auto i) { - ignore = i; - return b_grid_desc_bk0_n_bk1; - }, - Number{}); - - const auto block_work_idx = - block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - if(!block_2_ctile_map.ValidCTileIndex( - block_work_idx, - make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), - e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + if constexpr(IsValidCompilationParameter()) { - return; + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch); + const index_t k_idx = __builtin_amdgcn_readfirstlane( + (blockIdx.z - n_idx * karg.KBatch) * num_k_per_block); + + // offset base pointer for each work-group + const long_index_t a_batch_offset = + CTranspose + ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + CTranspose + ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + const long_index_t a_n_offset = + CTranspose ? 
0 + : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t b_n_offset = + CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) + : 0; + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + AsGridPointer p_as_grid_; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType_ = remove_cvref_t>; + p_as_grid_(i) = + static_cast(karg.p_as_grid[i]) + a_batch_offset + a_n_offset; + }); + + BsGridPointer p_bs_grid_; + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType_ = remove_cvref_t>; + p_bs_grid_(i) = + static_cast(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset; + }); + + DsGridPointer p_ds_grid_grp; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_batch_offset[i]; }); + + // Currently supporting one A and one B + const auto as_grid_desc_ak0_m_ak1 = generate_tuple( + [&](auto i) { + ignore = i; + return a_grid_desc_ak0_m_ak1; + }, + Number{}); + + const auto bs_grid_desc_bk0_n_bk1 = generate_tuple( + [&](auto i) { + ignore = i; + return b_grid_desc_bk0_n_bk1; + }, + Number{}); + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // AScale struct (Empty) + using AScale = typename BlockwiseGemmPipe::Empty; + auto a_scale_struct = AScale{}; + + // BScale struct (Empty) + using BScale = typename BlockwiseGemmPipe::Empty; + auto b_scale_struct = BScale{}; + + const index_t num_k_block_per_scale = GetKBlockPerScale(); + + Base::template Run(p_as_grid_, 
+ p_bs_grid_, + p_ds_grid_grp, + karg.p_e_grid + e_batch_offset + e_n_offset, + p_shared, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op, + block_m_id, + block_n_id, + num_k_block_per_scale, + a_scale_struct, + b_scale_struct, + epilogue_args, + k_idx, + k_idx, + karg.KBatch); } - - const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); - const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); - - // AScale struct (Empty) - using AScale = typename BlockwiseGemmPipe::Empty; - auto a_scale_struct = AScale{}; - - // BScale struct (Empty) - using BScale = typename BlockwiseGemmPipe::Empty; - auto b_scale_struct = BScale{}; - - const index_t num_k_block_per_scale = GetKBlockPerScale(); - - Base::template Run(p_as_grid_, - p_bs_grid_, - p_ds_grid_grp, - karg.p_e_grid + e_batch_offset + e_n_offset, - p_shared, - as_grid_desc_ak0_m_ak1, - bs_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - karg.a_element_op, - karg.b_element_op, - karg.cde_element_op, - block_m_id, - block_n_id, - num_k_block_per_scale, - a_scale_struct, - b_scale_struct, - epilogue_args, - k_idx, - k_idx, - karg.KBatch); } // Run method for convolution (grid descriptors are passed as arguments, @@ -1005,103 +1091,106 @@ struct GridwiseGemm_wmma_cshuffle_v3 Argument& karg, EpilogueArgument& epilogue_args) { - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); - const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); - - const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const 
long_index_t e_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); - - AsGridPointer p_as_grid_; - static_for<0, NumATensor, 1>{}([&](auto i) { - using ADataType_ = remove_cvref_t>; - p_as_grid_(i) = static_cast(karg.p_as_grid[i]) + a_batch_offset; - }); - - BsGridPointer p_bs_grid_; - static_for<0, NumBTensor, 1>{}([&](auto i) { - using BDataType_ = remove_cvref_t>; - p_bs_grid_(i) = static_cast(karg.p_bs_grid[i]) + b_batch_offset; - }); - - const auto ds_grid_desc_m_n = - MakeDsGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideDs); - - const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = - MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n, karg.MBlock, karg.NBlock); - - const auto as_grid_desc_ak0_m_ak1 = generate_tuple( - [&](auto i) { - ignore = i; - return a_grid_desc_ak0_m_ak1; - }, - Number{}); - - const auto bs_grid_desc_bk0_n_bk1 = generate_tuple( - [&](auto i) { - ignore = i; - return b_grid_desc_bk0_n_bk1; - }, - Number{}); - - // divide block work by [M, N] - const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4}; - - const auto block_work_idx = - block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - if(!block_2_ctile_map.ValidCTileIndex( - block_work_idx, - make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), - c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + if constexpr(IsValidCompilationParameter()) { - return; + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + 
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + + AsGridPointer p_as_grid_; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType_ = remove_cvref_t>; + p_as_grid_(i) = static_cast(karg.p_as_grid[i]) + a_batch_offset; + }); + + BsGridPointer p_bs_grid_; + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType_ = remove_cvref_t>; + p_bs_grid_(i) = static_cast(karg.p_bs_grid[i]) + b_batch_offset; + }); + + const auto ds_grid_desc_m_n = + MakeDsGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, karg.MBlock, karg.NBlock); + + const auto as_grid_desc_ak0_m_ak1 = generate_tuple( + [&](auto i) { + ignore = i; + return a_grid_desc_ak0_m_ak1; + }, + Number{}); + + const auto bs_grid_desc_bk0_n_bk1 = generate_tuple( + [&](auto i) { + ignore = i; + return b_grid_desc_bk0_n_bk1; + }, + Number{}); + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // Scale structs (Empty) + using Scale = typename BlockwiseGemmPipe::Empty; + auto b_scale_struct = Scale{}; + auto a_scale_struct = Scale{}; + + const index_t num_k_block_per_scale = GetKBlockPerScale(); + + Base::template Run(p_as_grid_, + p_bs_grid_, + karg.p_ds_grid, + karg.p_e_grid + e_batch_offset, + p_shared, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + 
ds_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_desc_mblock_mperblock_nblock_nperblock, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op, + block_m_id, + block_n_id, + num_k_block_per_scale, + a_scale_struct, + b_scale_struct, + epilogue_args, + k_idx, + k_idx, + karg.KBatch); } - - const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); - const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); - - // Scale structs (Empty) - using Scale = typename BlockwiseGemmPipe::Empty; - auto b_scale_struct = Scale{}; - auto a_scale_struct = Scale{}; - - const index_t num_k_block_per_scale = GetKBlockPerScale(); - - Base::template Run(p_as_grid_, - p_bs_grid_, - karg.p_ds_grid, - karg.p_e_grid + e_batch_offset, - p_shared, - as_grid_desc_ak0_m_ak1, - bs_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_desc_mblock_mperblock_nblock_nperblock, - karg.a_element_op, - karg.b_element_op, - karg.cde_element_op, - block_m_id, - block_n_id, - num_k_block_per_scale, - a_scale_struct, - b_scale_struct, - epilogue_args, - k_idx, - k_idx, - karg.KBatch); } // Run method for convolution fwd (grid descriptors are passed as arguments, @@ -1129,117 +1218,120 @@ struct GridwiseGemm_wmma_cshuffle_v3 Argument& karg, EpilogueArgument& epilogue_args) { - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch); - // offset base pointer for each work-group - const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t e_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); - - const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); - - const long_index_t a_n_offset 
= - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); - const long_index_t b_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetBPtrOffset(n_idx)); - const long_index_t e_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); - - const auto ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); - - AsGridPointer p_as_grid_; - static_for<0, NumATensor, 1>{}([&](auto i) { - using ADataType_ = remove_cvref_t>; - p_as_grid_(i) = - static_cast(karg.p_as_grid[i]) + a_batch_offset + a_n_offset; - }); - - BsGridPointer p_bs_grid_; - static_for<0, NumBTensor, 1>{}([&](auto i) { - using BDataType_ = remove_cvref_t>; - p_bs_grid_(i) = - static_cast(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset; - }); - - DsGridPointer p_ds_grid_grp; - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType_ = remove_cvref_t>; - p_ds_grid_grp(i) = static_cast(karg.p_ds_grid[i]) + - ds_batch_offset[i] + ds_n_offset[i]; - }); - - // Currently supporting one A and one B - const auto as_grid_desc_ak0_m_ak1 = generate_tuple( - [&](auto i) { - ignore = i; - return a_grid_desc_ak0_m_ak1; - }, - Number{}); - - const auto bs_grid_desc_bk0_n_bk1 = generate_tuple( - [&](auto i) { - ignore = i; - return b_grid_desc_bk0_n_bk1; - }, - Number{}); - - // divide block work by [M, N] - const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4}; - - const auto block_work_idx = - block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - if(!block_2_ctile_map.ValidCTileIndex( - block_work_idx, - make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), - e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + if constexpr(IsValidCompilationParameter()) { - return; + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch); + // offset base pointer for each work-group + const long_index_t 
a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t b_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetBPtrOffset(n_idx)); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + const auto ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); + + AsGridPointer p_as_grid_; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType_ = remove_cvref_t>; + p_as_grid_(i) = + static_cast(karg.p_as_grid[i]) + a_batch_offset + a_n_offset; + }); + + BsGridPointer p_bs_grid_; + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType_ = remove_cvref_t>; + p_bs_grid_(i) = + static_cast(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset; + }); + + DsGridPointer p_ds_grid_grp; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + p_ds_grid_grp(i) = static_cast(karg.p_ds_grid[i]) + + ds_batch_offset[i] + ds_n_offset[i]; + }); + + // Currently supporting one A and one B + const auto as_grid_desc_ak0_m_ak1 = generate_tuple( + [&](auto i) { + ignore = i; + return a_grid_desc_ak0_m_ak1; + }, + Number{}); + + const auto bs_grid_desc_bk0_n_bk1 = generate_tuple( + [&](auto i) { + ignore = i; + return b_grid_desc_bk0_n_bk1; + }, + Number{}); + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + 
if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // AScale struct (Empty) + using AScale = typename BlockwiseGemmPipe::Empty; + auto a_scale_struct = AScale{}; + + // BScale struct (Empty) + using BScale = typename BlockwiseGemmPipe::Empty; + auto b_scale_struct = BScale{}; + + const index_t num_k_block_per_scale = GetKBlockPerScale(); + + Base::template Run(p_as_grid_, + p_bs_grid_, + p_ds_grid_grp, + karg.p_e_grid + e_batch_offset + e_n_offset, + p_shared, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op, + block_m_id, + block_n_id, + num_k_block_per_scale, + a_scale_struct, + b_scale_struct, + epilogue_args); } - - const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); - const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); - - // AScale struct (Empty) - using AScale = typename BlockwiseGemmPipe::Empty; - auto a_scale_struct = AScale{}; - - // BScale struct (Empty) - using BScale = typename BlockwiseGemmPipe::Empty; - auto b_scale_struct = BScale{}; - - const index_t num_k_block_per_scale = GetKBlockPerScale(); - - Base::template Run(p_as_grid_, - p_bs_grid_, - p_ds_grid_grp, - karg.p_e_grid + e_batch_offset + e_n_offset, - p_shared, - as_grid_desc_ak0_m_ak1, - bs_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - karg.a_element_op, - karg.b_element_op, - karg.cde_element_op, - block_m_id, - block_n_id, - num_k_block_per_scale, - a_scale_struct, - 
b_scale_struct, - epilogue_args); } };