From 0b49f75e9e2061ccff7a46e871c35ceb283d29c0 Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Tue, 26 Nov 2024 08:45:14 +0100 Subject: [PATCH] CK-Tile first draft of universal block gemm with interwave & intrawave scheduler (#1676) * Block universal gemm. * Universal block gemm with interwave scheduler - draft. * Refactoring * Move a/b_warp_tiles into BlockGemmImpl * set BlockGemmImpl as a class member * Change tile size for more suitable to memory bound cases. * Introduce kKPerThread to WarpGemm * Add documentation comment. * Fix Interwave scheduler block gemm. * Add compute/memory friendly tile configuration. * Clean * New tile configurations in gemm mem example. * Add more static checks and fix loop order in block gemm. * Add more static checks and use warp gemm mfma dispatcher. * Add default scheduler block gemm. * Remove logging in example. [ROCm/composable_kernel commit: b6bcd76d881421af2f04246b1e4bbac45b7ce3b9] --- example/01_gemm/run_gemm_example_v2.inc | 2 +- example/ck_tile/03_gemm/gemm_mem_pipeline.cpp | 33 +- example/ck_tile/03_gemm/run_gemm_example.inc | 22 +- include/ck_tile/ops/gemm.hpp | 1 + .../block/block_universal_gemm_as_bs_cr.hpp | 661 ++++++++++++++++++ .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 12 +- .../gemm_pipeline_ag_bg_cr_scheduler.hpp | 2 + ...ine_agmem_bgmem_creg_v1_default_policy.hpp | 38 +- .../gemm/pipeline/gemm_pipeline_problem.hpp | 2 + .../gemm/warp/warp_gemm_attribute_mfma.hpp | 55 +- .../ck_tile/ops/gemm/warp/warp_gemm_impl.hpp | 7 +- 11 files changed, 779 insertions(+), 56 deletions(-) create mode 100644 include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp diff --git a/example/01_gemm/run_gemm_example_v2.inc b/example/01_gemm/run_gemm_example_v2.inc index 71524fdecf..5b6969f1d9 100644 --- a/example/01_gemm/run_gemm_example_v2.inc +++ b/example/01_gemm/run_gemm_example_v2.inc @@ -261,7 +261,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& 
config) if(config.time_kernel) { ave_time = - invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 5, 10, true, 4}); + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 50, 100, true, 4}); std::size_t flop = 2_uz * M * N * K; std::size_t num_btype = diff --git a/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp b/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp index ff9d8bad32..97d150412d 100644 --- a/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp +++ b/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp @@ -17,9 +17,24 @@ template float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) { - // ToDo: This will be modified by the codegen code later. +#if 1 + // Memory friendly for Interwave scheduler constexpr ck_tile::index_t M_Tile = 128; - constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t N_Tile = 32; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 4; + constexpr ck_tile::index_t N_Warp = 1; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 8; + +#else + // Compute friendly for Intrawave scheduler + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; constexpr ck_tile::index_t K_Tile = 32; constexpr ck_tile::index_t M_Warp = 2; @@ -28,12 +43,12 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 8; + constexpr ck_tile::index_t K_Warp_Tile = 16; +#endif - // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. 
- constexpr bool kPadM = true; - constexpr bool kPadN = true; - constexpr bool kPadK = true; + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; constexpr int kBlockPerCu = 1; @@ -174,8 +189,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) { std::ostringstream err; err << "When there's no hot loop, this tail number \"" << tail_num - << "\" is not supported! " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__; + << "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); } } diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 8db131738b..5199c1e3ef 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -31,15 +31,13 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, float ave_time = gemm_calc( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); - std::string op_name{"Gemm{MemBoundPipeline}"}; - std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time; - std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K + std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; @@ -114,7 +112,6 @@ int run_gemm_example_with_layouts(int argc, f_host_tensor_descriptor(M, N, stride_C, CLayout{})); // TODO: add different init types - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); 
ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); @@ -202,14 +199,15 @@ int run_gemm_example(int argc, char* argv[]) { return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); } - else if(a_layout == "C" && b_layout == "C") - { - return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); - } - else if(a_layout == "C" && b_layout == "R") - { - return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); - } + // TODO: Fixme: below is broken by latest GemmPipelineAGmemBGmemCRegV1DefaultPolicy changes. + // else if(a_layout == "C" && b_layout == "C") + // { + // return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + // } + // else if(a_layout == "C" && b_layout == "R") + // { + // return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + // } else { throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index ac74782a3a..9a033ee2de 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -22,6 +22,7 @@ #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" +#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp new file mode 100644 index 0000000000..5f98a7a0ba --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -0,0 +1,661 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" + +namespace ck_tile { + +// A is block window on shared memory +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockUniversalGemmAsBsCr +{ + private: + // TODO: This should be in Policy - UniversalGemmPolicyBase ? + template + struct GemmTraits_ + { + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr auto Scheduler = Problem::Scheduler; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WarpGemm = remove_cvref_t())>; + + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + + static_assert(MWarp == BlockGemmShape::BlockWarps::at(number<0>{}), + "Error! WarpGemm's MWarp is not consisten with BlockGemmShape!"); + static_assert(NWarp == BlockGemmShape::BlockWarps::at(number<1>{}), + "Error! WarpGemm's NWarp is not consisten with BlockGemmShape!"); + static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(number<0>{}), + "Error! WarpGemm's M is not consisten with BlockGemmShape!"); + static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(number<1>{}), + "Error! 
WarpGemm's N is not consisten with BlockGemmShape!"); + + static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + + static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock, + "Error! Warps should cover all Block tile!"); + static_assert(NIterPerWarp * NWarp * WarpGemm::kN == NPerBlock, + "Error! Warps should cover all Block tile!"); + + static constexpr index_t MPerBlockPerIter = MWarp * WarpGemm::kM; + static constexpr index_t NPerBlockPerIter = NWarp * WarpGemm::kN; + static constexpr index_t KPerBlockPerIter = WarpGemm::kK; + + using AWarpTileDistr = remove_cvref_t; + using BWarpTileDistr = remove_cvref_t; + + using AWarpTile = + remove_cvref_t(AWarpTileDistr{}))>; + using BWarpTile = + remove_cvref_t(BWarpTileDistr{}))>; + + // TODO: Should we have two policies? Interwave & Intrawave ?? + static constexpr index_t InterWaveSchedulingMacClusters = 1; + + static constexpr index_t KPack = WarpGemm::kKPerThread; + static constexpr index_t KPerThread = KPerBlock / WarpGemm::kK * KPack; + static constexpr index_t KRepeat = KPerThread / KPack; + }; + + public: + using Traits = GemmTraits_; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + using WarpGemm = remove_cvref_t; + + static constexpr index_t KIterPerWarp = Traits::KIterPerWarp; + static constexpr index_t MIterPerWarp = Traits::MIterPerWarp; + static constexpr index_t NIterPerWarp = Traits::NIterPerWarp; + + static constexpr index_t MWarp = Traits::MWarp; + static constexpr index_t NWarp = Traits::NWarp; + + static constexpr auto Scheduler = Traits::Scheduler; + + private: + template + struct BlockGemmImpl + { + }; + + template + struct BlockGemmImpl + { + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ASmemBlockWindow& 
a_block_window, + const BSmemBlockWindow& b_block_window) + { + static_assert( + std::is_same_v, + "The CDataType as defined in traits should be the same as correspoinding " + "C block tensor data type!"); + static_assert(std::is_same_v && + std::is_same_v, + "The ADataType and BDataType as defined in " + "traits should be the same as correspoinding block window data type!"); + + static_assert( + GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<0>{}] && + GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[number<0>{}] && + GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<1>{}], + "MPerBlock, NPerBlock, KPerBlock defined in " + " BlockGemmShape are different from A/B block smem windows apropriate dims!"); + + const index_t iMWarp = get_warp_id() / GemmTraits::NWarp; + const index_t iNWarp = get_warp_id() - (iMWarp * GemmTraits::NWarp); + + // TODO: refactor warp_window tile type to class member as it should be + // compile-time known information. 
+ auto a_warp_window_tmp = make_tile_window( + a_block_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_block_window.get_window_origin() + + multi_index<2>{iMWarp * GemmTraits::WarpGemm::kM, 0}, + make_static_tile_distribution(typename GemmTraits::WarpGemm::AWarpDstrEncoding{})); + + using AWarpWindow = remove_cvref_t; + + static_assert(GemmTraits::AWarpTile::get_num_of_dimension() == + AWarpWindow::get_num_of_dimension(), + "AWarpWindow number of dimensions must be equal to " + "AWarpTile number of dimensions!"); + static_assert(GemmTraits::AWarpTile::get_lengths() == + AWarpWindow{}.get_window_lengths(), + "AWarpWindow lengths must be equal to AWarpTile lengths!"); + + statically_indexed_array< + statically_indexed_array, + GemmTraits::MIterPerWarp> + a_warp_windows; + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_block_window.get_window_origin() + + multi_index<2>{iNWarp * GemmTraits::WarpGemm::kN, 0}, + make_static_tile_distribution(typename GemmTraits::WarpGemm::BWarpDstrEncoding{})); + + using BWarpWindow = remove_cvref_t; + + static_assert(GemmTraits::BWarpTile::get_num_of_dimension() == + BWarpWindow::get_num_of_dimension(), + "BWarpWindow number of dimensions must be equal to " + "BWarpTile number of dimensions!"); + static_assert(GemmTraits::BWarpTile::get_lengths() == + BWarpWindow{}.get_window_lengths(), + "BWarpWindow lengths must be equal to BWarpTile lengths!"); + + statically_indexed_array< + statically_indexed_array, + GemmTraits::NIterPerWarp> + b_warp_windows; + + static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + + // TODO: I don't have to move 0,0 window! 
+ move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * GemmTraits::MPerBlockPerIter, + kIter * GemmTraits::KPerBlockPerIter}); + }); + }); + + static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * GemmTraits::NPerBlockPerIter, + kIter * GemmTraits::KPerBlockPerIter}); + }); + }); + + using CWarpDstr = typename GemmTraits::WarpGemm::CWarpDstr; + using CWarpTensor = typename GemmTraits::WarpGemm::CWarpTensor; + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { + const auto a_warp_tile = load_tile(a_warp_windows(mIter)(kIter)); + + static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { + const auto b_warp_tile = load_tile(b_warp_windows(nIter)(kIter)); + + // read C warp tensor from C block tensor- + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + typename GemmTraits::WarpGemm{}(c_warp_tensor, a_warp_tile, b_warp_tile); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + }; + + template + struct BlockGemmImpl + { + statically_indexed_array< + statically_indexed_array, + GemmTraits::MIterPerWarp> + a_warp_tiles_; + + statically_indexed_array< + statically_indexed_array, + GemmTraits::NIterPerWarp> + b_warp_tiles_; + + 
template + CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window) + { + static_assert( + GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<0>{}] && + GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[number<0>{}] && + GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<1>{}], + "MPerBlock, NPerBlock, KPerBlock defined in " + " BlockGemmShape are different from A/B block smem windows apropriate dims!"); + + static_assert(std::is_same_v && + std::is_same_v, + "The ADataType and BDataType as defined in " + "traits should be the same as correspoinding block window data type!"); + + const index_t iMWarp = get_warp_id() / GemmTraits::NWarp; + const index_t iNWarp = get_warp_id() - (iMWarp * GemmTraits::NWarp); + + // TODO: refactor warp_window tile type to class member as it should be + // compile-time known information. + auto a_warp_window_tmp = make_tile_window( + a_block_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_block_window.get_window_origin() + + multi_index<2>{iMWarp * GemmTraits::WarpGemm::kM, 0}, + make_static_tile_distribution(typename GemmTraits::WarpGemm::AWarpDstrEncoding{})); + + using AWarpWindow = remove_cvref_t; + + static_assert(GemmTraits::AWarpTile::get_num_of_dimension() == + AWarpWindow::get_num_of_dimension(), + "AWarpWindow number of dimensions must be equal to " + "AWarpTile number of dimensions!"); + static_assert(GemmTraits::AWarpTile::get_lengths() == + AWarpWindow{}.get_window_lengths(), + "AWarpWindow lengths must be equal to AWarpTile lengths!"); + + statically_indexed_array< + statically_indexed_array, + GemmTraits::MIterPerWarp> + a_warp_windows; + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_block_window.get_window_origin() + + multi_index<2>{iNWarp * GemmTraits::WarpGemm::kN, 0}, + 
make_static_tile_distribution(typename GemmTraits::WarpGemm::BWarpDstrEncoding{})); + + using BWarpWindow = remove_cvref_t; + + static_assert(GemmTraits::BWarpTile::get_num_of_dimension() == + BWarpWindow::get_num_of_dimension(), + "BWarpWindow number of dimensions must be equal to " + "BWarpTile number of dimensions!"); + static_assert(GemmTraits::BWarpTile::get_lengths() == + BWarpWindow{}.get_window_lengths(), + "BWarpWindow lengths must be equal to BWarpTile lengths!"); + + statically_indexed_array< + statically_indexed_array, + GemmTraits::NIterPerWarp> + b_warp_windows; + + static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + + // TODO: I don't have to move 0,0 window! + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * GemmTraits::MPerBlockPerIter, + kIter * GemmTraits::KPerBlockPerIter}); + }); + }); + + static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * GemmTraits::NPerBlockPerIter, + kIter * GemmTraits::KPerBlockPerIter}); + }); + }); + + static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block window + load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter)); + }); + static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter)); + }); + }); + } + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + [[maybe_unused]] const ASmemBlockWindow& a_block_window, + [[maybe_unused]] const BSmemBlockWindow& b_block_window) + { + static_assert( + std::is_same_v, + "The 
CDataType as defined in traits should be the same as correspoinding " + "C block tensor data type!"); + + using CWarpDstr = typename GemmTraits::WarpGemm::CWarpDstr; + using CWarpTensor = typename GemmTraits::WarpGemm::CWarpTensor; + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor- + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + typename GemmTraits::WarpGemm{}(c_warp_tensor, + a_warp_tiles_[mIter][kIter], + b_warp_tiles_[nIter][kIter]); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + }; + + template + struct BlockGemmImpl + { + static constexpr index_t KPerThread = GemmTraits::KPerThread; + static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters; + static constexpr index_t KPerInnerLoop = + ck_tile::max(KPerThread / NumMacClusters, GemmTraits::KPack); + // TODO: do we really need this?? Are there any cases when this would be >=1 ?? + // Would we need InterWaveSchedulingMacClusters > 1 ??? 
+ static constexpr index_t KRepeat = KPerThread / KPerInnerLoop; + static constexpr index_t KInnerLoopIter = KPerInnerLoop / GemmTraits::KPack; + + statically_indexed_array< + statically_indexed_array, + GemmTraits::MIterPerWarp> + a_warp_tiles_; + + statically_indexed_array< + statically_indexed_array, + GemmTraits::NIterPerWarp> + b_warp_tiles_; + + template + CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window) + { + static_assert( + GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<0>{}] && + GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[number<0>{}] && + GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<1>{}], + "MPerBlock, NPerBlock, KPerBlock defined in " + " BlockGemmShape are different from A/B block smem windows apropriate dims!"); + + static_assert(std::is_same_v && + std::is_same_v, + "The ADataType and BDataType as defined in " + "traits should be the same as correspoinding block window data type!"); + + const index_t iMWarp = get_warp_id() / GemmTraits::NWarp; + const index_t iNWarp = get_warp_id() - (iMWarp * GemmTraits::NWarp); + + // TODO: refactor warp_window tile type to class member as it should be + // compile-time known information. 
+ auto a_warp_window_tmp = make_tile_window( + a_block_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_block_window.get_window_origin() + + multi_index<2>{iMWarp * GemmTraits::WarpGemm::kM, KIdx * KPerInnerLoop}, + make_static_tile_distribution(typename GemmTraits::WarpGemm::AWarpDstrEncoding{})); + + using AWarpWindow = remove_cvref_t; + + static_assert(GemmTraits::AWarpTile::get_num_of_dimension() == + AWarpWindow::get_num_of_dimension(), + "AWarpWindow number of dimensions must be equal to " + "AWarpTile number of dimensions!"); + static_assert(GemmTraits::AWarpTile::get_lengths() == + AWarpWindow{}.get_window_lengths(), + "AWarpWindow lengths must be equal to AWarpTile lengths!"); + + statically_indexed_array, + GemmTraits::MIterPerWarp> + a_warp_windows; + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_block_window.get_window_origin() + + multi_index<2>{iNWarp * GemmTraits::WarpGemm::kN, KIdx * KPerInnerLoop}, + make_static_tile_distribution(typename GemmTraits::WarpGemm::BWarpDstrEncoding{})); + + using BWarpWindow = remove_cvref_t; + + static_assert(GemmTraits::BWarpTile::get_num_of_dimension() == + BWarpWindow::get_num_of_dimension(), + "BWarpWindow number of dimensions must be equal to " + "BWarpTile number of dimensions!"); + static_assert(GemmTraits::BWarpTile::get_lengths() == + BWarpWindow{}.get_window_lengths(), + "BWarpWindow lengths must be equal to BWarpTile lengths!"); + + statically_indexed_array, + GemmTraits::NIterPerWarp> + b_warp_windows; + + static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * GemmTraits::MPerBlockPerIter, + kIter * GemmTraits::KPerBlockPerIter}); + }); + }); + + static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto 
nIter) { + static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * GemmTraits::NPerBlockPerIter, + kIter * GemmTraits::KPerBlockPerIter}); + }); + }); + + // TODO check if a_warp_tiles has same desc as a_warp_window + static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) { + static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block window + load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter)); + }); + static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter)); + }); + }); + } + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window) + { + static_assert( + std::is_same_v, + "The CDataType as defined in traits should be the same as correspoinding " + "C block tensor data type!"); + + using CWarpDstr = typename GemmTraits::WarpGemm::CWarpDstr; + using CWarpTensor = typename GemmTraits::WarpGemm::CWarpTensor; + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KRepeat, 1>{}([&](auto kIter) { + LocalPrefetch(a_block_window, b_block_window); + __builtin_amdgcn_sched_barrier(0); + // NOTE: Synchronize threads in a workgroup at the start of each MAC + // cluster, but except the first, as we can shorten non-MAC cluster a bit + // and there's no observable negative impact. The desired effect is waves in + // a workgroup executing MAC in sync. 
This avoids some out-of-sync waves + // hijacking MAC resource from other workgroups and reducing the chance of + // latency hiding by waiting for the rest of the workgroup at the eventual + // sync point. + if constexpr(kIter.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + + static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) { + static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor- + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = + c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from + // blockwise_gemm is moved here; B) reduce VMEM FIFO congestion + // by applying small delays to different wavefronts. It is + // performed near the end of MAC cluster to minimize lgkmcnt + // penalty. + if constexpr(kIter.value == KRepeat - 1 && + kInnerIter.value == KInnerLoopIter - 1 && + mIter.value == GemmTraits::MIterPerWarp - 1 && + nIter.value == GemmTraits::NIterPerWarp - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + // warp GEMM + typename GemmTraits::WarpGemm{}(c_warp_tensor, + a_warp_tiles_[mIter][kInnerIter], + b_warp_tiles_[nIter][kInnerIter]); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + + if constexpr(kInnerIter.value == 0 && mIter.value == 0 && + nIter.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + +
__builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + } + }; + + public: + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + + return c_block_tensor; + } + + template + CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window) + { + block_gemm_impl_.template LocalPrefetch(a_block_window, b_block_window); + } + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window) + { + block_gemm_impl_.template operator()(c_block_tensor, a_block_window, b_block_window); + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window) + { + auto c_block_tensor = MakeCBlockTile(); + block_gemm_impl_.template operator()(c_block_tensor, a_block_window, b_block_window); + return c_block_tensor; + } + + private: + BlockGemmImpl block_gemm_impl_{}; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index 85c5c58056..4634e9dcb9 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -247,8 +247,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem b_lds_block, 
make_tuple(number{}, number{}), {0, 0}); // Block GEMM - constexpr auto block_gemm = BlockGemm(); - auto c_block_tile = block_gemm.MakeCBlockTile(); + auto block_gemm = BlockGemm(); + auto c_block_tile = block_gemm.MakeCBlockTile(); using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); @@ -290,7 +290,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) { block_sync_lds(); - // block_gemm.LocalPrefetch(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); @@ -318,7 +318,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static_for<1, tail_num, 1>{}([&](auto prefetch_idx) { block_sync_lds(); - // block_gemm.LocalPrefetch(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); @@ -331,14 +331,14 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem }); block_sync_lds(); - // block_gemm.LocalPrefetch(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); }; if constexpr(TailNum == TailNumber::One) { block_sync_lds(); - // block_gemm.LocalPrefetch(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); } else if constexpr(TailNum == TailNumber::Two) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp index 5e93ca21c0..6f51e6b8a9 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp @@ -11,6 +11,7 @@ namespace ck_tile 
{ enum struct GemmPipelineScheduler { + Default, Intrawave, Interwave, }; @@ -43,6 +44,7 @@ inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineSch { switch(s) { + case ck_tile::GemmPipelineScheduler::Default: os << "Default"; break; case ck_tile::GemmPipelineScheduler::Intrawave: os << "Intrawave"; break; case ck_tile::GemmPipelineScheduler::Interwave: os << "Interwave"; break; default: os << ""; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp index c765b3ce9d..b475ebb7bd 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" namespace ck_tile { @@ -52,6 +53,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + // TODO: this 8 is AK1! should be a policy parameter! constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number<8>{}), make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), @@ -264,6 +266,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); constexpr index_t M0 = MPerBlock / (M2 * M1); + static_assert(M0 * M1 * M2 == MPerBlock, + "Incorrect M0, M2, M1 configuration! 
" + "M0, M1, M2 must cover whole MPerBlock!"); return make_static_tile_distribution( tile_distribution_encoding, @@ -277,6 +282,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy { constexpr index_t M0 = BlockSize / get_warp_size(); constexpr index_t M1 = MPerBlock / (M2 * M0); + static_assert(M0 * M1 * M2 == MPerBlock, + "Incorrect M0, M1, M2 configuration! " + "M0, M1, M2 must cover whole MPerBlock!"); return make_static_tile_distribution( tile_distribution_encoding, tuple, sequence>, @@ -350,6 +358,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error."); static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error."); constexpr index_t N0 = NPerBlock / (N2 * N1); + static_assert(N0 * N1 * N2 == NPerBlock, + "Incorrect N0, N1, N2 configuration! " + "N0, N1, N2 must cover whole NPerBlock!"); return make_static_tile_distribution( tile_distribution_encoding, @@ -364,7 +375,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy { constexpr index_t N0 = BlockSize / get_warp_size(); constexpr index_t N1 = NPerBlock / (N2 * N0); - + static_assert(N0 * N1 * N2 == NPerBlock, + "Incorrect N0, N1, N2 configuration! 
" + "N0, N1, N2 must cover whole NPerBlock!"); return make_static_tile_distribution( tile_distribution_encoding, tuple, sequence>, @@ -475,9 +488,28 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { - using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1DefaultPolicy; + constexpr bool TransposeC = false; + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + constexpr auto I2 = number<2>{}; - return BlockGemmASmemBSmemCRegV1{}; + using AccDataType = float; + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmMfmaDispatcher; + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; + + return BlockUniversalGemmAsBsCr{}; } }; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 3c43790bd6..bf51577aeb 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -33,6 +33,8 @@ struct GemmPipelineProblemBase static constexpr bool kPadN = GemmTraits::kPadN; static constexpr bool kPadK = GemmTraits::kPadK; + static constexpr auto Scheduler = GemmPipelineScheduler::Default; + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA() { if constexpr(std::is_same_v) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index 0a8d2dfbe3..a9e466a796 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -21,9 +21,10 @@ struct WarpGemmAtrributeMfma using BVecType = typename Impl::BVecType; using CVecType = typename Impl::CVecType; - static constexpr index_t kM = Impl::kM; - static constexpr index_t kN = Impl::kN; - static constexpr index_t kK = Impl::kK; + static constexpr index_t kM = Impl::kM; + static constexpr index_t kN = Impl::kN; + static constexpr index_t kK = Impl::kK; + static constexpr index_t kKPerThread = Impl::kABKPerLane; CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } @@ -86,9 +87,10 @@ struct WarpGemmAtrributeMfmaIterateK ext_vector_t::vector_size * kKIter>; using CVecType = typename Impl::CVecType; - static constexpr index_t kM = Impl::kM; - static constexpr index_t kN = Impl::kN; - static constexpr index_t kK = Impl::kK * kKIter; + static constexpr index_t kM = Impl::kM; + static constexpr index_t kN = Impl::kN; + static constexpr index_t kK = Impl::kK * kKIter; + static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter; CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } @@ -197,9 +199,10 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution using BVecType = typename Impl::AVecType; using CVecType = typename Impl::CVecType; - static constexpr index_t kM = Impl::kN; - static constexpr index_t kN = Impl::kM; - static constexpr index_t kK = Impl::kK; + static constexpr index_t kM = Impl::kN; + static constexpr index_t kN = Impl::kM; + static constexpr index_t kK = Impl::kK; + static constexpr index_t kKPerThread = Impl::kABKPerLane; CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } @@ -260,9 +263,10 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB using BVecType = typename Impl::AVecType; using CVecType = typename Impl::CVecType; - static constexpr index_t kM = Impl::kN; - static constexpr index_t kN = Impl::kM; - static constexpr index_t kK = Impl::kK; + static constexpr index_t kM = Impl::kN; + static constexpr index_t kN = 
Impl::kM; + static constexpr index_t kK = Impl::kK; + static constexpr index_t kKPerThread = Impl::kABKPerLane; CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } @@ -330,9 +334,10 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution ext_vector_t::vector_size * kKIter>; using CVecType = typename Impl::CVecType; - static constexpr index_t kM = Impl::kN; - static constexpr index_t kN = Impl::kM; - static constexpr index_t kK = Impl::kK * kKIter; + static constexpr index_t kM = Impl::kN; + static constexpr index_t kN = Impl::kM; + static constexpr index_t kK = Impl::kK * kKIter; + static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter; CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } @@ -444,10 +449,11 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB ext_vector_t::vector_size * kKIter>; using CVecType = typename Impl::CVecType; - static constexpr index_t kM = Impl::kN; - static constexpr index_t kN = Impl::kM; - static constexpr index_t kK = Impl::kK * kKIter; - static constexpr index_t SFactor = SFactor_; // group how many CM1 together + static constexpr index_t kM = Impl::kN; + static constexpr index_t kN = Impl::kM; + static constexpr index_t kK = Impl::kK * kKIter; + static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter; + static constexpr index_t SFactor = SFactor_; // group how many CM1 together CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } @@ -583,10 +589,11 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA ext_vector_t::vector_size * kKIter>; using CVecType = typename Impl::CVecType; - static constexpr index_t kM = Impl::kM; - static constexpr index_t kN = Impl::kN; - static constexpr index_t kK = Impl::kK * kKIter; - static constexpr index_t SFactor = SFactor_; // group how many CM1 together + static constexpr index_t kM = Impl::kM; + static constexpr index_t kN = Impl::kN; + static constexpr index_t kK = 
Impl::kK * kKIter; + static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter; + static constexpr index_t SFactor = SFactor_; // group how many CM1 together CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp index 182d023a00..f9d50ed35e 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,6 +14,11 @@ struct WarpGemmImpl static constexpr index_t kM = WarpGemmAttribute::kM; static constexpr index_t kN = WarpGemmAttribute::kN; static constexpr index_t kK = WarpGemmAttribute::kK; + /// @brief The number of elements along the K dimension processed by a single thread of a wavefront. + /// + /// @note WarpGemm may issue the MFMA instruction multiple times (each time on a different K slice). + /// In that case this value accounts for all of those iterations, not just a single instruction. + static constexpr index_t kKPerThread = WarpGemmAttribute::kKPerThread; using ADataType = typename WarpGemmAttribute::ADataType; using BDataType = typename WarpGemmAttribute::BDataType;