// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once #include #include #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" namespace ck_tile { struct FlatmmProblem { CK_TILE_HOST FlatmmProblem() = default; CK_TILE_HOST FlatmmProblem( index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_) : M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_) { } index_t M; index_t N; index_t K; index_t stride_A; index_t stride_B; index_t stride_C; }; template struct FlatmmScalePointer { using ScaleType = ScaleType_; static constexpr int GranularityMN = SharedGranularityMN; static constexpr int GranularityK = SharedGranularityK; const ScaleType* ptr; CK_TILE_HOST_DEVICE FlatmmScalePointer() = default; CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_) : ptr(ptr_) {} CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_, [[maybe_unused]] index_t length_) : ptr(ptr_) { } CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const { FlatmmScalePointer ret; if constexpr(GranularityMN == 0) { ret.ptr = ptr + offset / GranularityK; } else { ret.ptr = ptr + offset / GranularityMN / GranularityK; } return ret; } CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const = delete; }; template struct FlatmmScalePointer { using ScaleType = ScaleType_; static constexpr int GranularityMN = SharedGranularityMN; static constexpr int GranularityK = 0; static_assert(GranularityMN != 0); const ScaleType* ptr; index_t length; CK_TILE_HOST_DEVICE FlatmmScalePointer() = default; CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_) : ptr(ptr_), length(1) {} CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_, index_t length_) : ptr(ptr_), length(length_) { } CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const { FlatmmScalePointer ret; if constexpr(GranularityMN == 1) { ret.ptr = ptr + offset; ret.length = length - offset; } else { ret.ptr = ptr + offset / GranularityMN; ret.length = length - offset / GranularityMN; } return ret; } CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const { // with additional oob check if constexpr(GranularityMN == 1) return i < length ? ptr[i] : 0; else return i / GranularityMN < length ? ptr[i / GranularityMN] : 0; } }; // shared granularityMN = -1 means no scale template struct FlatmmScalePointer<-1, 0, ScaleType_> { using ScaleType = ScaleType_; static constexpr int GranularityMN = -1; static constexpr int GranularityK = 0; const ScaleType* ptr = nullptr; CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer() = default; CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const ScaleType*) {} CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const ScaleType*, index_t) {} CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer operator+(index_t) const { return FlatmmScalePointer{}; } CK_TILE_HOST_DEVICE constexpr ScaleType operator[](index_t) const { return 1; // alway return 1, it doesn't change the result } }; template struct BaseFlatmmHostArgs { CK_TILE_HOST BaseFlatmmHostArgs() = default; CK_TILE_HOST BaseFlatmmHostArgs(const void* a_ptr_, const void* b_ptr_, const std::array& ds_ptr_, void* e_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, const std::array& stride_Ds_, index_t stride_E_) : a_ptr(a_ptr_), b_ptr(b_ptr_), ds_ptr(ds_ptr_), e_ptr(e_ptr_), M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_Ds(stride_Ds_), stride_E(stride_E_), k_batch(k_batch_) { } const void* a_ptr; const void* b_ptr; const std::array ds_ptr; union { void* e_ptr; void* c_ptr; }; index_t M; index_t N; index_t K; index_t stride_A; index_t stride_B; const std::array stride_Ds; union { index_t stride_E; index_t stride_C; }; index_t k_batch; }; template , class ScaleN = FlatmmScalePointer<-1>, index_t NumDTensor = 0> struct ScaleFlatmmHostArgs : public BaseFlatmmHostArgs<> { CK_TILE_HOST ScaleFlatmmHostArgs() = default; CK_TILE_HOST ScaleFlatmmHostArgs(const void* a_ptr_, const void* b_shuffle_ptr_, const std::array& ds_ptr_, void* c_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, const std::array& stride_Ds_, index_t stride_C_, ScaleM scale_m_ = nullptr, ScaleN scale_n_ = nullptr) : BaseFlatmmHostArgs(a_ptr_, b_shuffle_ptr_, ds_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_Ds_, stride_C_), scale_m(scale_m_), scale_n(scale_n_) { } ScaleM scale_m = nullptr; ScaleN scale_n = nullptr; }; template using FlatmmHostArgs = ScaleFlatmmHostArgs, FlatmmScalePointer<-1>, NumberTensor>; template struct FlatmmKernelArgs { const void* a_ptr; // const void* b_shuffle_ptr; const void* b_ptr; const std::array ds_ptr; void* e_ptr; index_t M; index_t N; index_t K; index_t stride_A; index_t stride_B; std::array stride_Ds; index_t stride_E; index_t k_batch; ScaleM scale_m_ptr = nullptr; ScaleN scale_n_ptr = nullptr; }; template struct FlatmmKernel { using TilePartitioner = remove_cvref_t; using FlatmmPipeline = remove_cvref_t; using BlockGemmShape = remove_cvref_t; // TileFlatmmShape using EpiloguePipeline = remove_cvref_t; using ALayout = remove_cvref_t; using BLayout = remove_cvref_t; using ELayout = remove_cvref_t; using DsLayout = remove_cvref_t; using DsDataType = remove_cvref_t; static constexpr index_t kBlockSize = FlatmmPipeline::BlockSize; static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; // Below type is actually accumulation data type - the output of block GEMM. using EDataType = remove_cvref_t; static constexpr index_t NumDTensor = DsDataType::size(); static constexpr auto I0 = number<0>(); static constexpr auto I1 = number<1>(); static constexpr auto I2 = number<2>(); static constexpr auto I3 = number<3>(); static_assert(DsLayout::size() == DsDataType::size(), "The size of DsLayout and DsDataType should be the same"); // using KernelArgs = FlatmmKernelArgs; [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off return concat('_', "gemm", gemm_prec_str(), FlatmmPipeline::GetName()); // clang-format on } CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) { assert(!UsePersistentKernel); return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); } template CK_TILE_HOST static constexpr auto GridSize(const FlatmmKernelArgs& kargs) { if constexpr(UsePersistentKernel) { hipDeviceProp_t prop; int deviceId = 0; // default device constexpr int block_size = FlatmmKernel::BlockSize().x; int dync_smem_size = 0; int maxActiveBlocksPerCU = 0; [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId); e = hipOccupancyMaxActiveBlocksPerMultiprocessor( &maxActiveBlocksPerCU, reinterpret_cast( kentry<1, FlatmmKernel, FlatmmKernelArgs>), block_size, dync_smem_size); const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU; const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N); // std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU // << ", persistent_block_size: " << persistent_block_size // << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl; assert(kargs.k_batch == 1); return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch); } else { return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch); } } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } template CK_TILE_HOST static constexpr FlatmmKernelArgs MakeKernelArgs(const ScaleFlatmmHostArgs& hostArgs) { return {hostArgs.a_ptr, hostArgs.b_ptr, hostArgs.ds_ptr, hostArgs.e_ptr, hostArgs.M, hostArgs.N, hostArgs.K, hostArgs.stride_A, hostArgs.stride_B, hostArgs.stride_Ds, hostArgs.stride_E, hostArgs.k_batch, hostArgs.scale_m, hostArgs.scale_n}; } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPingSize() { return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPongSize() { return FlatmmPipeline::GetSmemSize(); } struct SplitKBatchOffset { template __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z) { constexpr auto N1 = BlockGemmShape::WarpTile::at(number<1>{}); constexpr auto K1 = BlockGemmShape::WarpTile::at(number<2>{}); const index_t K_t = kargs.k_batch * K1; const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; if constexpr(std::is_same_v) { a_k_split_offset = k_id * KRead; } else if constexpr(std::is_same_v) { a_k_split_offset = k_id * KRead * kargs.stride_A; } if constexpr(std::is_same_v) { b_k_split_offset = k_id * KRead * kargs.stride_B * N1; } else if constexpr(std::is_same_v) { b_k_split_offset = k_id * KRead * N1; } if(k_id < static_cast(kargs.k_batch - 1)) { splitted_k = KRead; } else { splitted_k = kargs.K - KRead * (kargs.k_batch - 1); } } index_t a_k_split_offset; index_t b_k_split_offset; index_t splitted_k; }; template CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs) { if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && is_any_of::value) { if(kargs.k_batch != 1) { std::cerr << "Conditions not met for Kbatch >1 !" << std::endl; return false; } } if constexpr(UsePersistentKernel) { if(kargs.k_batch != 1) { std::cerr << "Persistent mode doesn't support Kbatch >1 !" << std::endl; return false; } } if constexpr(std::is_same_v) { if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false) { std::cerr << "Can't support K that is not a multiple of KPerBlock" " without padding!" << std::endl; return false; } if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0) { std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl; return false; } } else { if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false) { std::cerr << "Can't support M that is not a multiple of MPerBlock" " without padding!" << std::endl; return false; } if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0) { std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl; return false; } } if constexpr(std::is_same_v) { if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false) { std::cerr << "Can't support N that is not a multiple of NPerBlock" " without padding!" << std::endl; return false; } if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0) { std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl; return false; } } else { if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false) { std::cerr << "Can't support K that is not a multiple of KPerBlock" " without padding!" << std::endl; return false; } if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0) { std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl; return false; } } bool DTesnorIsValid = {true}; static_for<0, NumDTensor, 1>{}([&](auto index) { using DiLayout = remove_cvref_t>; if(std::is_same_v == false) { DTesnorIsValid = false; } if constexpr(std::is_same_v) { if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false) { CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of " "NPerBlock without padding!"); DTesnorIsValid = false; } if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0) { CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!"); DTesnorIsValid = false; } } else { if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false) { CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of " "MPerBlock without padding!"); DTesnorIsValid = false; } if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0) { CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!"); DTesnorIsValid = false; } } }); if constexpr(std::is_same_v) { if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false) { std::cerr << "Can't support N that is not a multiple of NPerBlock" " without padding!" << std::endl; return false; } if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0) { std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl; return false; } } else { if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false) { std::cerr << "Can't support M that is not a multiple of MPerBlock" " without padding!" << std::endl; return false; } if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0) { std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl; return false; } } return DTesnorIsValid; } template CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr, const KernelArgs& kargs, const index_t k_size, const index_t block_idx_m) { // Step 1: Create tensor view const auto& a_tensor_view = [&]() { if constexpr(std::is_same_v) { return make_naive_tensor_view( a_ptr, make_tuple(kargs.M, k_size), make_tuple(kargs.stride_A, 1), number{}, number<1>{}); } else { return make_naive_tensor_view( a_ptr, make_tuple(k_size, kargs.M), make_tuple(kargs.stride_A, 1), number{}, number<1>{}); } }(); // Step 2: Create padded view const auto& a_pad_view = [&]() { if constexpr(std::is_same_v) { return pad_tensor_view(a_tensor_view, make_tuple(number{}, number{}), sequence{}); } else { return pad_tensor_view(a_tensor_view, make_tuple(number{}, number{}), sequence{}); } }(); // Step 3: Create tile window if constexpr(std::is_same_v) { return make_tile_window(a_pad_view, make_tuple(number{}, number{}), {block_idx_m, 0}); } else { return make_tile_window(a_pad_view, make_tuple(number{}, number{}), {0, block_idx_m}); } } template CK_TILE_DEVICE static auto MakeBFlatBlockWindow(const BDataType* b_flat_ptr, const KernelArgs& kargs, const index_t block_idx_n) { // Step 1: Create tensor view index_t kFlatK = FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2)); index_t kFlatN = kargs.N * kargs.K / kFlatK; const auto& b_flat_tensor_view = make_naive_tensor_view( b_flat_ptr, make_tuple(kFlatN, kFlatK), make_tuple(kFlatK, 1), number{}, number<1>{}); // Step 2: No padding needed for b_flat // Step 3: Create tile window return make_tile_window( b_flat_tensor_view, make_tuple(number{}, number{}), {static_cast(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0}); } template CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array& ds_ptr, const KernelArgs& kargs, const index_t block_idx_m, const index_t block_idx_n) { // Step 1: Create tensor views const auto& ds_tensor_view = generate_tuple( [&](auto i) { using DiLayout = remove_cvref_t>; using DDataType_ = remove_cvref_t>; if constexpr(std::is_same_v) { return make_naive_tensor_view( static_cast(ds_ptr[i]), make_tuple(kargs.M, kargs.N), make_tuple(kargs.stride_Ds[i], 1), number{}, number<1>{}); } else { return make_naive_tensor_view( static_cast(ds_ptr[i]), make_tuple(kargs.N, kargs.M), make_tuple(kargs.stride_Ds[i], 1), number{}, number<1>{}); } }, number{}); // Step 2: Create padded views const auto& ds_pad_view = generate_tuple( [&](auto i) { using DiLayout = remove_cvref_t>; if constexpr(std::is_same_v) { return pad_tensor_view(ds_tensor_view[i], make_tuple(number{}, number{}), sequence{}); } else { return pad_tensor_view(ds_tensor_view[i], make_tuple(number{}, number{}), sequence{}); } }, number{}); // Step 3: Create tile windows return generate_tuple( [&](auto i) { using DiLayout = remove_cvref_t>; if constexpr(std::is_same_v) { return make_tile_window(ds_pad_view[i], make_tuple(number{}, number{}), {block_idx_m, block_idx_n}); } else { return make_tile_window(ds_pad_view[i], make_tuple(number{}, number{}), {block_idx_n, block_idx_m}); } }, number{}); } template CK_TILE_DEVICE static auto MakeEBlockWindow(EDataType* e_ptr, const KernelArgs& kargs, const index_t block_idx_m, const index_t block_idx_n) { // Step 1: Create tensor view const auto& e_tensor_view = [&]() { if constexpr(std::is_same_v) { return make_naive_tensor_view( e_ptr, make_tuple(kargs.M, kargs.N), make_tuple(kargs.stride_E, 1), number{}, number<1>{}); } else { return make_naive_tensor_view( e_ptr, make_tuple(kargs.N, kargs.M), make_tuple(kargs.stride_E, 1), number<1>{}, number<1>{}); } }(); // Step 2: Create padded view const auto& e_pad_view = [&]() { if constexpr(std::is_same_v) { return pad_tensor_view(e_tensor_view, make_tuple(number{}, number{}), sequence{}); } else { return pad_tensor_view(e_tensor_view, make_tuple(number{}, number{}), sequence{}); } }(); // Step 3: Create tile window return make_tile_window( e_pad_view, make_tuple(number{}, number{}), {block_idx_m, block_idx_n}); } template CK_TILE_DEVICE static auto MakeScaleMWindow(const KernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m) { constexpr int GM = decltype(kargs.scale_m_ptr)::GranularityMN; constexpr int GK = decltype(kargs.scale_m_ptr)::GranularityK; static_assert(GM != -1, "MakeScaleMWindow should only be instantiated when scale is enabled"); // per-tensor (GM==0) -> Mdim = 1, stride 0 const index_t m_dim = (GM == 0) ? 1 : (kargs.M / GM); const index_t m_stride = (GM == 0) ? 0 : 1; const index_t k_dim = (GK == 0) ? 1 : (splitk_batch_offset.splitted_k / GK); const index_t k_stride = 0; // your original code keeps K stride 0 const auto scale_m_view = make_naive_tensor_view( kargs.scale_m_ptr.ptr, make_tuple(m_dim, k_dim), make_tuple(m_stride, k_stride), number < (GM == 1) ? FlatmmPipeline::GetVectorSizeA() : 1 > {}, number<1>{}); // Window extents: if GM==0, we still just broadcast from [0,*] return make_tile_window(scale_m_view, make_tuple(number{}, number < (GK == 0) ? TilePartitioner::NPerBlock : TilePartitioner::KPerBlock > {}), {block_idx_m, 0}); } template CK_TILE_DEVICE static auto MakeScaleNWindow(const KernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_n) { constexpr int GN = decltype(kargs.scale_n_ptr)::GranularityMN; constexpr int GK = decltype(kargs.scale_n_ptr)::GranularityK; static_assert(GN != -1, "MakeScaleNWindow should only be instantiated when scale is enabled"); // per-tensor (GN==0) -> Ndim = 1, stride 0 const index_t n_dim = (GN == 0) ? 1 : (kargs.N / GN); const index_t n_stride = (GN == 0) ? 0 : 1; const index_t k_dim = (GK == 0) ? 1 : (splitk_batch_offset.splitted_k / GK); const index_t k_stride = 0; const auto scale_n_view = make_naive_tensor_view( kargs.scale_n_ptr.ptr, make_tuple(k_dim, n_dim), make_tuple(k_stride, n_stride), number < (GN == 1) ? FlatmmPipeline::GetVectorSizeB() : 1 > {}, number<1>{}); return make_tile_window(scale_n_view, make_tuple(number < (GK == 0) ? TilePartitioner::MPerBlock : TilePartitioner::KPerBlock > {}, number{}), {0, block_idx_n}); } template CK_TILE_DEVICE static void RunFlatmm(const ADataType* a_ptr, const BDataType* b_flat_ptr, const std::array& ds_ptr, EDataType* e_ptr, void* smem_ptr_ping, void* smem_ptr_pong, const FlatmmKernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n) { // Create block windows using specialized methods const auto& a_block_window = MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n); const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); // Run GEMM cooperatively by whole workgroup. const auto& c_block_tile = FlatmmPipeline{}.template operator()( a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong); // Run Epilogue Pipeline with k_batch dispatching if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1) { const auto& scale_m_window = MakeScaleMWindow(kargs, splitk_batch_offset, block_idx_m); const auto& scale_n_window = MakeScaleNWindow(kargs, splitk_batch_offset, block_idx_n); if(kargs.k_batch == 1) { auto e_block_window = MakeEBlockWindow( e_ptr, kargs, block_idx_m, block_idx_n); EpiloguePipeline{} .template operator()(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping, scale_m_window, scale_n_window); } else { auto e_block_window = MakeEBlockWindow( e_ptr, kargs, block_idx_m, block_idx_n); EpiloguePipeline{} .template operator()(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping, scale_m_window, scale_n_window); } } else if(UseDefaultScheduler || (get_warp_id() == 0)) { if(kargs.k_batch == 1) { auto e_block_window = MakeEBlockWindow( e_ptr, kargs, block_idx_m, block_idx_n); EpiloguePipeline{} .template operator()( e_block_window, c_block_tile, ds_block_window, smem_ptr_ping); } else { auto e_block_window = MakeEBlockWindow( e_ptr, kargs, block_idx_m, block_idx_n); EpiloguePipeline{} .template operator()( e_block_window, c_block_tile, ds_block_window, smem_ptr_ping); } } } template CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs, int partition_idx = blockIdx.x) const { int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N); do { const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx); const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); const SplitKBatchOffset splitk_batch_offset(kargs); // options const ADataType* a_ptr = static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; const BDataType* b_flat_ptr = static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; EDataType* e_ptr = static_cast(kargs.e_ptr); // allocate LDS __shared__ char smem_ptr_ping[GetSmemPingSize()]; __shared__ char smem_ptr_pong[GetSmemPongSize()]; if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && is_any_of::value)) { constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1); RunFlatmm(a_ptr, b_flat_ptr, kargs.ds_ptr, e_ptr, smem_ptr_ping, smem_ptr_pong, kargs, splitk_batch_offset, i_m, i_n); } partition_idx += gridDim.x; } while(UsePersistentKernel && partition_idx < total_work_tile_cnt); } }; } // namespace ck_tile