diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp
index bcce4ef9ca..3a67fe7602 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp
@@ -7,6 +7,7 @@
 #include
 
+#include "ck/ck.hpp"
 #include "ck/library/utility/numeric.hpp"
 #include "ck/utility/common_header.hpp"
 #include "ck/utility/env.hpp"
@@ -100,17 +101,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
     if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm)
     {
-        GridwiseGemm::template Run(
+        GridwiseGemm::template Run<
+            TailNum,
+            decltype(epilogue_args)>(
             p_shared,
             gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
             gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
@@ -127,17 +131,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
     {
         if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
         {
-            GridwiseGemm::template Run(
+            GridwiseGemm::template Run<
+                TailNum,
+                decltype(epilogue_args)>(
                 p_shared,
                 gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
                 gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
@@ -152,17 +160,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
         }
         else
         {
-            GridwiseGemm::template Run(
+            GridwiseGemm::template Run<
+                TailNum,
+                decltype(epilogue_args)>(
                 p_shared,
                 gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
                 gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp
index f662ff834f..bfb567d1e0 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp
@@ -7,6 +7,7 @@
 #include
 #include
 
+#include "ck/ck.hpp"
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
@@ -28,6 +29,7 @@
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
 #include "ck/host_utility/flush_cache.hpp"
+#include "ck/utility/tuple.hpp"
 
 #ifdef CK_EXPERIMENTAL_BUILDER
 #include "ck_tile/builder/reflect/description.hpp"
@@ -71,23 +73,34 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
             typename GridwiseGemm::EpilogueCShuffle>();
 
     __shared__ char p_shared[LDS_size];
 
-    auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
+    const auto block_2_ctile_map_ = typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4};
+    auto epilogue_args            = typename GridwiseGemm::EpilogueCShuffle{};
 
     GridwiseGemm::template Run<
+        ck::Tuple<>, // Empty tuple
         CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
+        decltype(block_2_ctile_map_),
         ComputePtrOffsetOfBatch,
+        ComputePtrOffsetOfBatch, // placeholder
         1,
         HasMainKBlockLoop,
         CGlobalMemoryDataOperation,
-        TailNum>(p_shared,
-                 a_grid_desc_ak0_m_ak1,
-                 b_grid_desc_bk0_n_bk1,
-                 c_grid_desc_mblock_mperblock_nblock_nperblock,
-                 compute_ptr_offset_of_batch,
-                 num_k_per_block,
-                 karg,
-                 epilogue_args);
+        false,
+        TailNum,
+        decltype(epilogue_args)>(
+        p_shared,
+        a_grid_desc_ak0_m_ak1,
+        b_grid_desc_bk0_n_bk1,
+        ck::Tuple<>(), // placeholder
+        c_grid_desc_mblock_mperblock_nblock_nperblock,
+        block_2_ctile_map_,
+        compute_ptr_offset_of_batch,
+        ComputePtrOffsetOfBatch{}, // placeholder
+        num_k_per_block,
+        karg,
+        epilogue_args);
 
 #if defined(__gfx11__)
     }
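Reviewer note (outside the patch): each bwd_weight device instance now builds `typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4}` on the device side and hands it to the unified Run(), instead of Run() constructing it internally. The map turns a 1-D workgroup id into a 2-D (MBlock, NBlock) tile index, grouping a few M-adjacent tiles so consecutive workgroups reuse the same B-tile columns in cache. A simplified host-side sketch of that swizzle (names are illustrative; CK's BlockToCTileMap_M00_N0_M01Adapt additionally shrinks the group at the M tail, which is omitted here by assuming MBlock % M01 == 0):

    #include <cassert>
    #include <utility>

    // Map a linear block id onto a (m_tile, n_tile) coordinate, visiting
    // M01 M-adjacent tiles before moving to the next N column.
    std::pair<int, int> map_block_to_ctile(int block_1d_id, int MBlock, int NBlock, int M01 = 4)
    {
        assert(MBlock % M01 == 0);      // tail-group adaptation omitted for brevity
        block_1d_id %= MBlock * NBlock; // wrap if more blocks than tiles

        const int group      = block_1d_id / (M01 * NBlock); // which group of M01 rows
        const int in_group   = block_1d_id % (M01 * NBlock);
        const int m_in_group = in_group % M01; // row advances fastest ...
        const int n_tile     = in_group / M01; // ... so M01 row-neighbours share a column

        return {group * M01 + m_in_group, n_tile};
    }

With M01 = 1 this degenerates to a plain row-major tile order; the constant 4 in the patch keeps four M-neighbouring tiles back-to-back.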
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp
index f9b2ff0596..053f0eb3ae 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp
@@ -7,6 +7,7 @@
 #include
 #include
 
+#include "ck/ck.hpp"
 #include "ck/utility/common_header.hpp"
 #include "ck/utility/env.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
@@ -29,6 +30,7 @@
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
 #include "ck/host_utility/flush_cache.hpp"
+#include "ck/utility/tuple.hpp"
 
 #ifdef CK_EXPERIMENTAL_BUILDER
 #include "ck_tile/builder/reflect/description.hpp"
@@ -71,23 +73,35 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
             typename GridwiseGemm::EpilogueCShuffle>();
 
     __shared__ char p_shared[LDS_size];
 
-    auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
+    const auto block_2_ctile_map_ = typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4};
+    auto epilogue_args            = typename GridwiseGemm::EpilogueCShuffle{};
 
     GridwiseGemm::template Run<
+        ck::Tuple<>, // Empty tuple
         CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
+        decltype(block_2_ctile_map_),
         ComputePtrOffsetOfBatch,
+        ComputePtrOffsetOfBatch, // placeholder
         NumGroupsToMerge,
         HasMainKBlockLoop,
         CGlobalMemoryDataOperation,
-        TailNum>(p_shared,
-                 a_grid_desc_ak0_m_ak1,
-                 b_grid_desc_bk0_n_bk1,
-                 c_grid_desc_mblock_mperblock_nblock_nperblock,
-                 compute_ptr_offset_of_batch,
-                 num_k_per_block,
-                 karg,
-                 epilogue_args);
+        false,
+        TailNum,
+        decltype(epilogue_args)>(
+        p_shared,
+        a_grid_desc_ak0_m_ak1,
+        b_grid_desc_bk0_n_bk1,
+        ck::Tuple<>(), // placeholder
+        c_grid_desc_mblock_mperblock_nblock_nperblock,
+        block_2_ctile_map_,
+        compute_ptr_offset_of_batch,
+        ComputePtrOffsetOfBatch{}, // placeholder
+        num_k_per_block,
+        karg,
+        epilogue_args);
+
 #if defined(__gfx11__)
     }
 #endif
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp
index b2ae092c27..2bce582f68 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp
@@ -7,6 +7,7 @@
 #include
 #include
 
+#include "ck/ck.hpp"
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
@@ -71,23 +72,34 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
             typename GridwiseGemm::EpilogueCShuffle>();
 
     __shared__ char p_shared[LDS_size];
 
-    auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
+    const auto block_2_ctile_map_ = typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4};
+    auto epilogue_args            = typename GridwiseGemm::EpilogueCShuffle{};
 
     GridwiseGemm::template Run<
+        ck::Tuple<>, // Empty tuple
         CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
+        decltype(block_2_ctile_map_),
         ComputePtrOffsetOfBatch,
+        ComputePtrOffsetOfBatch, // placeholder
         1,
         HasMainKBlockLoop,
         CGlobalMemoryDataOperation,
-        TailNum>(p_shared,
-                 a_grid_desc_ak0_m_ak1,
-                 b_grid_desc_bk0_n_bk1,
-                 c_grid_desc_mblock_mperblock_nblock_nperblock,
-                 compute_ptr_offset_of_batch,
-                 num_k_per_block,
-                 karg,
-                 epilogue_args);
+        false,
+        TailNum,
+        decltype(epilogue_args)>(
+        p_shared,
+        a_grid_desc_ak0_m_ak1,
+        b_grid_desc_bk0_n_bk1,
+        ck::Tuple<>(), // placeholder
+        c_grid_desc_mblock_mperblock_nblock_nperblock,
+        block_2_ctile_map_,
+        compute_ptr_offset_of_batch,
+        ComputePtrOffsetOfBatch{}, // placeholder
+        num_k_per_block,
+        karg,
+        epilogue_args);
 
 #if defined(__gfx11__)
    }
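Reviewer note (outside the patch): the `ck::Tuple<>(), // placeholder` and `ComputePtrOffsetOfBatch{}, // placeholder` arguments above exist only to satisfy the unified Run() signature; the regime selected at compile time never reads them, so the branches that would are discarded by `if constexpr` and an empty dummy type-checks. A minimal sketch of that pattern with invented names:

    #include <cstdio>

    enum class Regime { Fwd, BwdWeight };

    struct EmptyOffset {}; // stand-in for an unused per-N offset calculator

    // One entry point for every regime; offset_n is only touched on the Fwd path,
    // so BwdWeight callers can pass EmptyOffset{} as a placeholder.
    template <Regime R, typename OffsetN>
    void run(int n_idx, const OffsetN& offset_n)
    {
        if constexpr(R == Regime::Fwd)
        {
            // Only this instantiation requires OffsetN::Get().
            std::printf("fwd per-N offset: %ld\n", offset_n.Get(n_idx));
        }
        else
        {
            std::printf("bwd_weight: no per-N offset\n"); // offset_n never used
        }
    }

    struct RealOffset
    {
        long Get(int i) const { return 64L * i; }
    };

    int main()
    {
        run<Regime::Fwd>(2, RealOffset{});        // real offset object
        run<Regime::BwdWeight>(0, EmptyOffset{}); // placeholder, branch compiled out
    }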
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp
index df252da8b4..8781a3c38a 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp
@@ -105,24 +105,34 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
     const auto b_grid_desc_bk0_n_bk1 =
         GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
 
-    GridwiseGemm::template Run(p_shared,
-                               a_grid_desc_ak0_m_ak1,
-                               b_grid_desc_bk0_n_bk1,
-                               ds_grid_desc_mblock_mperblock_nblock_nperblock,
-                               e_grid_desc_mblock_mperblock_nblock_nperblock,
-                               compute_ptr_offset_of_batch,
-                               compute_ptr_offset_of_n,
-                               num_k_per_block,
-                               karg,
-                               epilogue_args);
+    GridwiseGemm::template Run<
+        false,
+        TailNum,
+        decltype(epilogue_args)>(
+        p_shared,
+        a_grid_desc_ak0_m_ak1,
+        b_grid_desc_bk0_n_bk1,
+        ds_grid_desc_mblock_mperblock_nblock_nperblock,
+        e_grid_desc_mblock_mperblock_nblock_nperblock,
+        block_2_ctile_map_,
+        compute_ptr_offset_of_batch,
+        compute_ptr_offset_of_n,
+        num_k_per_block,
+        karg,
+        epilogue_args);
+
 #if defined(__gfx11__)
     }
 #endif
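Reviewer note (outside the patch): the gridwise rewrite below replaces the per-regime Run() overloads with one function whose regime-dependent values (g_idx, k_idx, b_n_offset, the Ds descriptor) are initialized through immediately-invoked lambdas. That keeps each result a single const variable while letting every regime's expression sit in its own `if constexpr` branch, as the in-diff comment about clang and nested ternaries notes. The shape of the idiom, reduced to a standalone sketch (regime tags and formulas are illustrative):

    // C++17: initialize a const value from compile-time-selected logic.
    template <int Regime> // 0 = fwd, 1 = bwd_data, 2 = bwd_weight
    int compute_k_idx(int block_idx_y, int block_idx_z, int k_batch, int num_k_per_block)
    {
        const int n_idx = block_idx_z / k_batch;

        const int k_idx = [&]() -> int {
            if constexpr(Regime == 1) // bwd_data: z encodes both N-batch and K-split
            {
                return (block_idx_z - n_idx * k_batch) * num_k_per_block;
            }
            else if constexpr(Regime == 2) // bwd_weight: y carries the K-split
            {
                return block_idx_y * num_k_per_block;
            }
            else // fwd: no K-split
            {
                return 0;
            }
        }(); // invoked immediately; k_idx stays const

        return k_idx;
    }

A nested ternary could express the same selection, but every operand would have to compile for every regime; the discarded `if constexpr` branches are what let the placeholder arguments through.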
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp
index a34170df88..7818074b7f 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp
@@ -3,6 +3,8 @@
 
 #pragma once
 
+#include "ck/ck.hpp"
+#include "ck/utility/array.hpp"
 #include "ck/utility/env.hpp"
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/multi_index_transform_helper.hpp"
@@ -844,313 +846,126 @@ struct GridwiseGemm_wmma_cshuffle_v3
         return Block2CTileMap{M, N, 4};
     }
 
-    // Run method for convolution for bwd_data (grid descriptors are passed as arguments,
-    // not generated internally)
-    template
+    // Unified Run() function for all regimes (bwd_data, bwd_weight, fwd)
+    template
     __device__ static void Run(void* p_shared,
                                const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
                                const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
-                               const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-                                   ds_grid_desc_mblock_mperblock_nblock_nperblock,
-                               const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-                                   e_grid_desc_mblock_mperblock_nblock_nperblock,
+                               const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
+                                   ds_grid_desc_mblock_mperblock_nblock_nperblock_,
+                               const CEGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
+                                   ce_grid_desc_mblock_mperblock_nblock_nperblock,
                                const Block2CTileMapExt& block_2_ctile_map,
-                               const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-                               const ComputePtrOffsetOfN compute_ptr_offset_of_n,
+                               const ComputePtrOffsetOfBatch& compute_ptr_offset_of_batch,
+                               const ComputePtrOffsetOfN& compute_ptr_offset_of_n,
                                const index_t num_k_per_block,
                                Argument& karg,
                                EpilogueArgument& epilogue_args)
     {
-        const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
-        const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch);
-        const index_t k_idx =
-            __builtin_amdgcn_readfirstlane((blockIdx.z - n_idx * karg.KBatch) * num_k_per_block);
-        // offset base pointer for each work-group
+        // Resolve the current regime at compile time:
+        constexpr bool is_bwd_data   = (Regime == ConvRegime::BWD_DATA);
+        constexpr bool is_bwd_weight = (Regime == ConvRegime::BWD_WEIGHT);
+        constexpr bool is_fwd        = (Regime == ConvRegime::FORWARD);
+
+        // ======== Index =========
+        const auto g_idx = [&]() -> index_t {
+            if constexpr(is_bwd_data || is_fwd)
+            {
+                return __builtin_amdgcn_readfirstlane(blockIdx.y);
+            }
+            else
+            {
+                return __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
+            }
+        }();
+
+        const index_t n_idx =
+            (is_bwd_data || is_fwd) ? __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch) : 0;
+
+        // Using a lambda for better clang compliance than nested ternary operators
+        const auto k_idx = [&]() -> index_t {
+            if constexpr(is_bwd_data)
+            {
+                return __builtin_amdgcn_readfirstlane((blockIdx.z - n_idx * karg.KBatch) *
+                                                      num_k_per_block);
+            }
+            else if constexpr(is_bwd_weight)
+            {
+                return __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
+            }
+            else
+            {
+                return 0;
+            }
+        }();
+
+        // ======== Offset ========
         const long_index_t a_batch_offset =
             CTranspose
                 ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))
                 : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
         const long_index_t b_batch_offset =
             CTranspose
                 ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))
                 : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
         const long_index_t e_batch_offset =
             amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
 
         const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
 
         const long_index_t a_n_offset =
-            CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
-        const long_index_t b_n_offset =
-            CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0;
-        const long_index_t e_n_offset =
-            amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
+            (!CTranspose) ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx))
+                          : 0;
 
-        AsGridPointer p_as_grid_;
-        static_for<0, NumATensor, 1>{}([&](auto i) {
-            using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
-            p_as_grid_(i) =
-                static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset + a_n_offset;
-        });
+        // b_n_offset
+        const auto b_n_offset = [&]() -> long_index_t {
+            if constexpr(is_bwd_data)
+            {
+                return CTranspose
+                           ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx))
+                           : 0;
+            }
+            else if constexpr(is_fwd)
+            {
+                return amd_wave_read_first_lane(compute_ptr_offset_of_n.GetBPtrOffset(n_idx));
+            }
+            else
+            {
+                return 0;
+            }
+        }();
 
-        BsGridPointer p_bs_grid_;
-        static_for<0, NumBTensor, 1>{}([&](auto i) {
-            using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
-            p_bs_grid_(i) =
-                static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset;
-        });
-
-        DsGridPointer p_ds_grid_grp;
-        static_for<0, NumDTensor, 1>{}(
-            [&](auto i) { p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_batch_offset[i]; });
-
-        // Currently supporting one A and one B
-        const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
-            [&](auto i) {
-                ignore = i;
-                return a_grid_desc_ak0_m_ak1;
-            },
-            Number<NumATensor>{});
-
-        const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
-            [&](auto i) {
-                ignore = i;
-                return b_grid_desc_bk0_n_bk1;
-            },
-            Number<NumBTensor>{});
-
-        const auto block_work_idx =
-            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
-
-        if(!block_2_ctile_map.ValidCTileIndex(
-               block_work_idx,
-               make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
-                          e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
-        {
-            return;
-        }
-
-        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
-        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
-
-        // AScale struct (Empty)
-        using AScale = typename BlockwiseGemmPipe::Empty;
-        auto a_scale_struct = AScale{};
-
-        // BScale struct (Empty)
-        using BScale = typename BlockwiseGemmPipe::Empty;
-        auto b_scale_struct = BScale{};
-
-        const index_t num_k_block_per_scale = GetKBlockPerScale();
-
-        Base::template Run(p_as_grid_,
-                           p_bs_grid_,
-                           p_ds_grid_grp,
-                           karg.p_e_grid + e_batch_offset + e_n_offset,
-                           p_shared,
-                           as_grid_desc_ak0_m_ak1,
-                           bs_grid_desc_bk0_n_bk1,
-                           ds_grid_desc_mblock_mperblock_nblock_nperblock,
-                           e_grid_desc_mblock_mperblock_nblock_nperblock,
-                           karg.a_element_op,
-                           karg.b_element_op,
-                           karg.cde_element_op,
-                           block_m_id,
-                           block_n_id,
-                           num_k_block_per_scale,
-                           a_scale_struct,
-                           b_scale_struct,
-                           epilogue_args,
-                           k_idx,
-                           k_idx,
-                           karg.KBatch);
-    }
-
-    // Run method for convolution (grid descriptors are passed as arguments,
-    // not generated internally)
-    template
-    __device__ static void Run(void* p_shared,
-                               const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
-                               const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
-                               const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
-                               const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-                               const index_t num_k_per_block,
-                               Argument& karg,
-                               EpilogueArgument& epilogue_args)
-    {
-        const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
-        const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
-
-        const long_index_t a_batch_offset =
-            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
-        const long_index_t b_batch_offset =
-            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
-        const long_index_t e_batch_offset =
-            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
-
-        AsGridPointer p_as_grid_;
-        static_for<0, NumATensor, 1>{}([&](auto i) {
-            using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
-            p_as_grid_(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset;
-        });
-
-        BsGridPointer p_bs_grid_;
-        static_for<0, NumBTensor, 1>{}([&](auto i) {
-            using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
-            p_bs_grid_(i) =
-                static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset;
-        });
-
-        const auto ds_grid_desc_m_n =
-            MakeDsGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideDs);
-
-        const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
-            MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                ds_grid_desc_m_n, karg.MBlock, karg.NBlock);
-
-        const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
-            [&](auto i) {
-                ignore = i;
-                return a_grid_desc_ak0_m_ak1;
-            },
-            Number<NumATensor>{});
-
-        const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
-            [&](auto i) {
-                ignore = i;
-                return b_grid_desc_bk0_n_bk1;
-            },
-            Number<NumBTensor>{});
-
-        // divide block work by [M, N]
-        const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
-
-        const auto block_work_idx =
-            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
-
-        if(!block_2_ctile_map.ValidCTileIndex(
-               block_work_idx,
-               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
-                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
-        {
-            return;
-        }
-
-        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
-        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
-
-        // Scale structs (Empty)
-        using Scale = typename BlockwiseGemmPipe::Empty;
-        auto b_scale_struct = Scale{};
-        auto a_scale_struct = Scale{};
-
-        const index_t num_k_block_per_scale = GetKBlockPerScale();
-
-        Base::template Run(p_as_grid_,
-                           p_bs_grid_,
-                           karg.p_ds_grid,
-                           karg.p_e_grid + e_batch_offset,
-                           p_shared,
-                           as_grid_desc_ak0_m_ak1,
-                           bs_grid_desc_bk0_n_bk1,
-                           ds_grid_desc_mblock_mperblock_nblock_nperblock,
-                           c_grid_desc_mblock_mperblock_nblock_nperblock,
-                           karg.a_element_op,
-                           karg.b_element_op,
-                           karg.cde_element_op,
-                           block_m_id,
-                           block_n_id,
-                           num_k_block_per_scale,
-                           a_scale_struct,
-                           b_scale_struct,
-                           epilogue_args,
-                           k_idx,
-                           k_idx,
-                           karg.KBatch);
-    }
-
-    // Run method for convolution fwd (grid descriptors are passed as arguments,
-    // not generated internally)
-    template
-    __device__ static void Run(void* p_shared,
-                               const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
-                               const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
-                               const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
-                                   ds_grid_desc_mblock_mperblock_nblock_nperblock,
-                               const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
-                                   e_grid_desc_mblock_mperblock_nblock_nperblock,
-                               const ComputePtrOffsetOfBatch& compute_ptr_offset_of_batch,
-                               const ComputePtrOffsetOfN& compute_ptr_offset_of_n,
-                               [[maybe_unused]] const index_t num_k_per_block,
-                               Argument& karg,
-                               EpilogueArgument& epilogue_args)
-    {
-        const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
-        const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch);
-        // offset base pointer for each work-group
-        const long_index_t a_batch_offset =
-            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
-        const long_index_t b_batch_offset =
-            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
-        const long_index_t e_batch_offset =
-            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
-
-        const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
-
-        const long_index_t a_n_offset =
-            amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
-        const long_index_t b_n_offset =
-            amd_wave_read_first_lane(compute_ptr_offset_of_n.GetBPtrOffset(n_idx));
-        const long_index_t e_n_offset =
+        const auto e_n_offset =
             amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
 
         const auto ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);
 
+        // ======== Grid pointers ======== //
+
         AsGridPointer p_as_grid_;
         static_for<0, NumATensor, 1>{}([&](auto i) {
             using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
@@ -1167,12 +982,34 @@ struct GridwiseGemm_wmma_cshuffle_v3
 
         DsGridPointer p_ds_grid_grp;
         static_for<0, NumDTensor, 1>{}([&](auto i) {
-            using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
-            p_ds_grid_grp(i) = static_cast<const DDataType_*>(karg.p_ds_grid[i]) +
-                               ds_batch_offset[i] + ds_n_offset[i];
+            if constexpr(is_bwd_data)
+            {
+                p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_batch_offset[i];
+            }
+            else if constexpr(is_fwd)
+            {
+                using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
+                p_ds_grid_grp(i) = static_cast<const DDataType_*>(karg.p_ds_grid[i]) +
+                                   ds_batch_offset[i] + ds_n_offset[i];
+            }
         });
 
-        // Currently supporting one A and one B
+        // ======== Grid descriptors ======== //
+
+        const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = [&]() {
+            if constexpr(is_bwd_weight)
+            {
+                const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
+                    karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideDs);
+                return MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                    ds_grid_desc_m_n, karg.MBlock, karg.NBlock);
+            }
+            else
+            {
+                return ds_grid_desc_mblock_mperblock_nblock_nperblock_;
+            }
+        }();
+
         const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
             [&](auto i) {
                 ignore = i;
@@ -1187,51 +1024,62 @@ struct GridwiseGemm_wmma_cshuffle_v3
             },
             Number<NumBTensor>{});
 
-        // divide block work by [M, N]
-        const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
+        // ======== Tiling ======== //
 
         const auto block_work_idx =
             block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
 
         if(!block_2_ctile_map.ValidCTileIndex(
                block_work_idx,
-               make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
-                          e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
+               make_tuple(ce_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
+                          ce_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
         {
            return;
        }
 
+        // ======== Remaining Run() arguments ======== //
+
         const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
         const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
 
-        // AScale struct (Empty)
-        using AScale = typename BlockwiseGemmPipe::Empty;
-        auto a_scale_struct = AScale{};
-
-        // BScale struct (Empty)
-        using BScale = typename BlockwiseGemmPipe::Empty;
-        auto b_scale_struct = BScale{};
+        // Scale structs (Empty)
+        using Scale = typename BlockwiseGemmPipe::Empty;
+        auto a_scale_struct = Scale{};
+        auto b_scale_struct = Scale{};
 
         const index_t num_k_block_per_scale = GetKBlockPerScale();
 
+        // p_ds_grid_
+        const auto p_ds_grid_ = (is_bwd_data || is_fwd) ? p_ds_grid_grp : karg.p_ds_grid;
+
+        // p_e_grid_
+        const auto p_e_grid_ = karg.p_e_grid + e_batch_offset + e_n_offset;
+
+        // Final arguments
+        const index_t A_k_id  = k_idx;
+        const index_t B_k_id  = k_idx;
+        const index_t k_batch = (is_bwd_data || is_bwd_weight) ? karg.KBatch : 1;
+
+        // ======= Call the Run() function ======== //
+
         Base::template Run(p_as_grid_,
                            p_bs_grid_,
-                           p_ds_grid_grp,
-                           karg.p_e_grid + e_batch_offset + e_n_offset,
+                           p_ds_grid_,
+                           p_e_grid_,
                            p_shared,
                            as_grid_desc_ak0_m_ak1,
                            bs_grid_desc_bk0_n_bk1,
                            ds_grid_desc_mblock_mperblock_nblock_nperblock,
-                           e_grid_desc_mblock_mperblock_nblock_nperblock,
+                           ce_grid_desc_mblock_mperblock_nblock_nperblock,
                            karg.a_element_op,
                            karg.b_element_op,
                            karg.cde_element_op,
@@ -1240,7 +1088,10 @@ struct GridwiseGemm_wmma_cshuffle_v3
                            num_k_block_per_scale,
                            a_scale_struct,
                            b_scale_struct,
-                           epilogue_args);
+                           epilogue_args,
+                           A_k_id,
+                           B_k_id,
+                           k_batch);
     }
 };
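Reviewer note (outside the patch): after this change, one Run() decodes (group, N-batch, K-split) from blockIdx.y/z in three regime-specific ways, where the deleted overloads each hard-coded one decoding. A host-side mirror of the three decodings for reference (launch shapes illustrative, matching the lambdas above):

    #include <cstdio>

    struct Decoded
    {
        int g_idx, n_idx, k_idx;
    };

    // fwd: y = group, z = n * KBatch (no K-split offset used)
    Decoded decode_fwd(int by, int bz, int k_batch) { return {by, bz / k_batch, 0}; }

    // bwd_data: y = group, z packs (n, k-split)
    Decoded decode_bwd_data(int by, int bz, int k_batch, int num_k_per_block)
    {
        const int n = bz / k_batch;
        return {by, n, (bz - n * k_batch) * num_k_per_block};
    }

    // bwd_weight: z = merged group, y = k-split, n unused
    Decoded decode_bwd_weight(int by, int bz, int num_groups_to_merge, int num_k_per_block)
    {
        return {bz * num_groups_to_merge, 0, by * num_k_per_block};
    }

    int main()
    {
        const Decoded d =
            decode_bwd_data(/*by=*/1, /*bz=*/5, /*k_batch=*/2, /*num_k_per_block=*/8);
        std::printf("g=%d n=%d k=%d\n", d.g_idx, d.n_idx, d.k_idx); // g=1 n=2 k=8
    }

The k_batch forwarded to the base Run() is karg.KBatch only for the two backward regimes; forward passes 1, matching the decode_fwd convention above.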