diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 5f7e78fac2..e4a2908a53 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -989,6 +989,78 @@ struct UniversalGemmKernel } } + // Version of RunGemm using descriptors + template + CK_TILE_DEVICE static void RunGemmDesc(const std::array& as_ptr, + const std::array& bs_ptr, + const std::array& ds_ptr, + EDataType* e_ptr, + void* smem_ptr_0, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n, + const std::array& as_desc, + const std::array& bs_desc, + const std::array& ds_desc, + const EGridDesc& e_desc) + { + // Create tensor views from descriptors (supports arbitrary stride patterns) + const auto& as_tensor_view = generate_tuple( + [&](auto i) { + using AiDataType = remove_cvref_t>; + return make_tensor_view( + static_cast(as_ptr[i]), as_desc[i]); + }, + number{}); + + const auto& bs_tensor_view = generate_tuple( + [&](auto i) { + using BiDataType = remove_cvref_t>; + return make_tensor_view( + static_cast(bs_ptr[i]), bs_desc[i]); + }, + number{}); + + const auto& ds_tensor_view = generate_tuple( + [&](auto i) { + using DiDataType = remove_cvref_t>; + return make_tensor_view( + static_cast(ds_ptr[i]), ds_desc[i]); + }, + number{}); + + auto e_tensor_view = + make_tensor_view(static_cast(e_ptr), e_desc); + + const auto& gemm_tensors_views_tuple = + make_tuple(as_tensor_view, bs_tensor_view, ds_tensor_view, e_tensor_view); + + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensors_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = + amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + + // Run GEMM cooperatively by whole workgroup. + const auto& as_block_window = gemm_tile_windows.at(I0); + const auto& bs_block_window = gemm_tile_windows.at(I1); + const auto& ds_block_window = gemm_tile_windows.at(I2); + + const auto& c_block_tile = GemmPipeline{}.template operator()( + as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0); + + if(UseDefaultScheduler || (get_warp_id() == 0)) + { + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + + EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } + } + /** * @brief Runs single GEMM problem cooperatively by whole workgroup. *