From 16920dee0fff2920229ea8bef8ce8578dc756d03 Mon Sep 17 00:00:00 2001 From: kiefer Date: Wed, 20 Aug 2025 08:56:53 +0000 Subject: [PATCH] Add support for fwd conv in gridwise implementation. Identical to run function for bwd data. --- .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 125 ++++++++++++++++++ .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 37 ++++-- 2 files changed, 152 insertions(+), 10 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index d37eebaed2..3eb57ccda3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -663,6 +663,131 @@ struct GridwiseGemm_wmma_cshuffle_v3 karg.b_element_op, karg.cde_element_op); } + + // Run method for convolution (grid descriptors are passed as arguments, + // not generated internally). Grid: blockIdx.y = g_idx, blockIdx.z = n_idx * KBatch + k-split + template + __device__ static void Run(void* p_shared, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const ComputePtrOffsetOfN compute_ptr_offset_of_n, + const index_t num_k_per_block, + Argument& karg) + { + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch); + const index_t k_idx = + __builtin_amdgcn_readfirstlane((blockIdx.z - n_idx * karg.KBatch) * num_k_per_block); + + // offset base pointer for each work-group + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = +
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + AsGridPointer p_as_grid_; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType_ = remove_cvref_t>; + p_as_grid_(i) = + static_cast(karg.p_as_grid[i]) + a_batch_offset + a_n_offset; + }); + + BsGridPointer p_bs_grid_; + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType_ = remove_cvref_t>; + p_bs_grid_(i) = static_cast(karg.p_bs_grid[i]) + b_batch_offset; + }); + + DsGridPointer p_ds_grid_grp; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_batch_offset[i]; }); + + // Currently supporting one A and one B + const auto as_grid_desc_ak0_m_ak1 = generate_tuple( + [&](auto i) { + ignore = i; + return a_grid_desc_ak0_m_ak1; + }, + Number{}); + + const auto bs_grid_desc_bk0_n_bk1 = generate_tuple( + [&](auto i) { + ignore = i; + return b_grid_desc_bk0_n_bk1; + }, + Number{}); + + // divide block work by [M, N]; a work-group mapped to an invalid C tile returns early + const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // BScale struct (Empty) + using BScale = typename
BlockwiseGemmPipe::Empty; + auto b_scale_struct = BScale{}; + + const index_t num_k_block_per_scale = GetKBlockPerScale(); + + Base::template Run(p_as_grid_, + p_bs_grid_, + p_ds_grid_grp, + karg.p_e_grid + e_batch_offset + e_n_offset, + p_shared, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op, + block_m_id, + block_n_id, + num_k_block_per_scale, + b_scale_struct, + karg.KBatch, + k_idx); + } }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index c8407a08ca..c39f9b22fa 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -162,11 +162,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + // Calculate grid size taking into account splitk (KBatch) + // 2D grid (x,z): x = Block2CTileMap grid size for (M, N), y = 1, z = KBatch __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) { return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); } + // Calculate grid size taking into account splitk (KBatch) and multiple groups (Batch) + // 3D grid (x,y,z): x = Block2CTileMap grid size for (M, N), y = KBatch, z = Batch + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch); + } + __host__ static auto CalculateMPadded(index_t M) { return math::integer_least_multiple(M, MPerBlock); @@ -594,8 +603,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base } template - __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - const DsGridDesc&
ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + __device__ __host__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc& ds_grid_desc_m_n, + index_t MBlock, + index_t NBlock) { return generate_tuple( [&](auto i) { @@ -918,8 +929,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base KPack>())>; template - __device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock) + __host__ __device__ static constexpr auto + MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc& de_grid_desc_m_n, + index_t MBlock, + index_t NBlock) { const auto de_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( de_grid_desc_m_n, @@ -1180,6 +1193,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base } } + // Note: arguments k_batch and k_id should be set if split-K is used with implicit GEMM: + // K is sliced through the tensor descriptors (block-begin index k_id), not by shifting pointers template 1) { const auto idx_as_block_begin = generate_tuple( - [&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); }, + [&](auto) { return make_multi_index(k_id, m_block_data_idx_on_grid, 0); }, Number{}); return ThreadGroupTensorSliceTransfer_v7r2< @@ -1307,7 +1324,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base true, BlockwiseGemmPipe::GlobalBufferNum>( as_grid_desc_ak0_m_ak1[I0], - make_multi_index(0, m_block_data_idx_on_grid, 0), + make_multi_index(k_id, m_block_data_idx_on_grid, 0), a_element_op, a_block_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -1323,7 +1340,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base if constexpr(NumBTensor > 1) { const auto idx_bs_block_begin = generate_tuple( - [&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); }, + [&](auto) { return make_multi_index(k_id, n_block_data_idx_on_grid, 0); }, Number{}); return ThreadGroupTensorSliceTransfer_v7r2< @@ -1377,7 +1394,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base true,
BlockwiseGemmPipe::GlobalBufferNum>( bs_grid_desc_bk0_n_bk1[I0], - make_multi_index(0, n_block_data_idx_on_grid, 0), + make_multi_index(k_id, n_block_data_idx_on_grid, 0), b_element_op, b_block_desc_bk0_n_bk1, make_multi_index(0, 0, 0), @@ -1411,7 +1428,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) / - KPerBlock); + (KPerBlock * k_batch)); blockwise_gemm_pipeline.template Run( get_first_element_workaround(as_grid_desc_ak0_m_ak1),