mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Joye/revise wp pipeline (#3493)
* [CK_TILE] unify double and single lds implementation (#108) Unify LDS buffer management API for single and double buffering modes This change consolidates the Local Data Store (LDS) buffer management by: Merging single and double LDS buffer APIs into a unified interface Implementing ping-pong address calculation in pipeline when double LDS is enabled Computing pong buffer addresses dynamically using base address offsets --------- Co-authored-by: joye <joye@amd.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * update wp_pipeline * fix a c++17 issue * update for ci errors * fix ci issues * include a header to fix ci errors * fix some rebase issues * update with rebase --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -25,6 +25,7 @@
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
|
||||
|
||||
212
include/ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp
Normal file
212
include/ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp
Normal file
@@ -0,0 +1,212 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
|
||||
namespace ck_tile {

// Block-level GEMM building block for the weight-preshuffle (wp) pipeline.
//   A is a block window on shared memory (LDS)
//   B is a block window on register (pre-shuffled "flat" weight tile)
//   C is a block distributed tensor (per-thread accumulator registers)
//
// A-tiles are preloaded from LDS into a small ring of warp-level register
// tensors (`preloaded_a_warp_tensor`) so that ds_read latency can be hidden
// behind the warp GEMM issue stream.
template <typename Problem_, typename BlockPolicy_>
struct BlockWeightPreshuffleASmemBRegCReg
{
    using Problem     = remove_cvref_t<Problem_>;
    using BlockPolicy = remove_cvref_t<BlockPolicy_>;
    using ADataType   = remove_cvref_t<typename Problem::ADataType>;
    using BDataType   = remove_cvref_t<typename Problem::BDataType>;
    using CDataType   = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    static constexpr auto I0 = number<0>();
    static constexpr auto I1 = number<1>();
    static constexpr auto I2 = number<2>();
    // Dimension-index aliases into BlockGemmShape tuples: M, N, K.
    static constexpr auto idxM = I0;
    static constexpr auto idxN = I1;
    static constexpr auto idxK = I2;
    using BlockTile  = remove_cvref_t<typename BlockGemmShape::BlockTile>;
    using BlockWarps = remove_cvref_t<typename BlockGemmShape::BlockWarps>;
    using WarpTile   = remove_cvref_t<typename BlockGemmShape::WarpTile>;

    // Per-workgroup tile extents.
    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    // Policy decides the WarpGemm instruction and the MWarp x NWarp layout.
    static constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
    using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;

    static constexpr index_t MWarp = config.template at<1>();
    static constexpr index_t NWarp = config.template at<2>();

    // How many warp-GEMM iterations each warp performs per block tile.
    static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
    static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
    static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;

    // Stride (in elements of the block tile) between consecutive M/K iterations.
    static constexpr index_t MPerBlockPerIter = MWarp * WarpGemm::kM;
    static constexpr index_t KPerBlockPerIter = WarpGemm::kK;

    static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read

    // Depth of the A-tile preload ring; clamped so we never preload more
    // tiles than the block actually has (MIterPerWarp * KIterPerWarp).
    static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
                                             ? DsReadPreload
                                             : MIterPerWarp * KIterPerWarp;

    using AWarpTensor = typename WarpGemm::AWarpTensor;
    // Ring of preloaded A warp tensors, indexed modulo m_preload (see operator()).
    statically_indexed_array<AWarpTensor, m_preload> preloaded_a_warp_tensor;

    // Build the tile-distribution encoding used to read one WarpGemm-sized
    // A slice out of LDS: outer (MWarp x NWarp) layout embedded with the
    // WarpGemm's own A distribution.
    CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
    {
        constexpr auto a_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<NWarp>,
                                       tuple<sequence<1, MWarp>, sequence<1>>,
                                       tuple<sequence<1, 0>>,
                                       tuple<sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};
        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});

        return a_block_dstr_encode;
    }

    // Slice the LDS block window into a KIterPerWarp x MIterPerWarp grid of
    // per-iteration load windows, each carrying the A load distribution.
    // Returned as a_load_windows[kIter][mIter].
    template <typename SmemBlockWindow>
    CK_TILE_DEVICE auto MakeALoadWindows(SmemBlockWindow& a_block_window) const
    {
        constexpr auto a_load_dstr = make_static_tile_distribution(MakeABlockDistributionEncode());

        // create MIterPerWarp × KIterPerWarp window
        return generate_tuple(
            [&](auto kIter) {
                return generate_tuple(
                    [&](auto mIter) {
                        return make_tile_window(
                            get_slice_tile(
                                a_block_window,
                                sequence<mIter * MPerBlockPerIter, kIter * KPerBlockPerIter>{},
                                sequence<(mIter + 1) * MPerBlockPerIter,
                                         (kIter + 1) * KPerBlockPerIter>{}),
                            a_load_dstr);
                    },
                    number<MIterPerWarp>{});
            },
            number<KIterPerWarp>{});
    }

    // Fill the preload ring with the first m_preload A tiles from LDS.
    // Iteration order is M-fastest (mIter = loadIter % MIterPerWarp), which
    // matches the consumption order in operator().
    template <typename ALoadWindows>
    CK_TILE_DEVICE void LocalPrefetch(const ALoadWindows& a_load_windows)
    {

        static_for<0, m_preload, 1>{}([&](auto loadIter) {
            constexpr auto mIter = loadIter % MIterPerWarp;
            constexpr auto kIter = loadIter / MIterPerWarp;

            load_tile(preloaded_a_warp_tensor(loadIter),
                      a_load_windows[number<kIter>{}][number<mIter>{}]);
        });
    }

    // Create the zero-initializable C accumulator block tensor: outer
    // (MIter x MWarp, NIter x NWarp) layout embedded with WarpGemm's C
    // distribution.
    CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
    {
        constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
            sequence<>,
            tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
            tuple<sequence<1, 2>>,
            tuple<sequence<1, 1>>,
            sequence<1, 2>,
            sequence<0, 0>>{};

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});

        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);

        auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
        return c_block_tensor;
    }

    // C += A * B
    //
    // Fully unrolled K -> M -> N loop nest of warp GEMMs. A comes from the
    // preload ring (filled by LocalPrefetch); after consuming a slot, the
    // slot is refilled with the tile m_preload iterations ahead, so ds_reads
    // overlap with MFMA issue. Expects LocalPrefetch to have been called and
    // the LDS A buffer to be valid on entry.
    //
    // NOTE(review): statement order here is performance-critical — the
    // sched_barrier and the placement of the refill load_tile relative to
    // set_y_sliced_thread_data are deliberate; do not reorder.
    template <typename CBlockTensor,
              typename ALoadWindows,
              typename BFlatBlockTensor,
              typename BFlatDistribution>
    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                                   const ALoadWindows& a_load_windows,
                                   BFlatBlockTensor& b_block_tensor,
                                   const BFlatDistribution&)
    {
        // M iteration at which to issue the end-of-tile LDS barrier: second
        // to last when possible, so the sync overlaps the final warp GEMMs.
        constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;

        using CWarpDstr   = typename WarpGemm::CWarpDstr;
        using CWarpTensor = typename WarpGemm::CWarpTensor;

        using BWarpTensor = typename WarpGemm::BWarpTensor;

        constexpr auto b_block_y_lengths =
            to_sequence(BFlatDistribution{}.get_ys_to_d_descriptor().get_lengths());

        constexpr auto c_warp_y_lengths =
            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());

        constexpr auto b_block_y_index_zeros =
            uniform_sequence_gen_t<BFlatDistribution::NDimY, 0>{};
        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

        static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                // Ring slot holding the A tile for this (kIter, mIter).
                constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                    // read C warp tensor from C block tensor
                    BWarpTensor b_warp_tensor;
                    CWarpTensor c_warp_tensor;

                    // Slice the (nIter, kIter) B warp tile out of the flat
                    // register-resident B block tensor.
                    b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<nIter, kIter>{},
                                        typename sequence_split<decltype(b_block_y_index_zeros),
                                                                2>::right_type{}),
                        merge_sequences(
                            sequence<1, 1>{},
                            typename sequence_split<decltype(b_block_y_lengths), 2>::right_type{}));

                    c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                    // warp GEMM
                    WarpGemm{}(
                        c_warp_tensor, preloaded_a_warp_tensor(number<AwarpIter>{}), b_warp_tensor);

                    // write C warp tensor into C block tensor
                    c_block_tensor.set_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                        c_warp_tensor.get_thread_buffer());

                    // Scheduling fence between N iterations; mask 0x7F6
                    // presumably constrains which instruction classes the
                    // compiler may move across — TODO confirm against the
                    // amdgcn sched_barrier mask definition.
                    __builtin_amdgcn_sched_barrier(0x7F6);
                });
                // preload next A from lds
                // Refill the just-consumed ring slot with the tile m_preload
                // iterations ahead, unless we are within the last m_preload
                // iterations (nothing left to fetch).
                if constexpr((kIter * MIterPerWarp + mIter) <
                             (KIterPerWarp * MIterPerWarp - m_preload))
                {
                    constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
                    constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);

                    load_tile(preloaded_a_warp_tensor(number<AwarpIter>{}),
                              a_load_windows[number<AkIter>{}][number<AmIter>{}]);
                }

                // barrier
                // LDS sync once all reads for this block tile have been
                // issued, so the producer may overwrite the A buffer.
                if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
                {
                    block_sync_lds();
                }
            });
        });
    }
};

} // namespace ck_tile
|
||||
@@ -303,24 +303,15 @@ struct GroupedGemmKernel
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// TO DO:
|
||||
// Can we simplify this branching logic?
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
RunGemmWithPipelineSelection2LDS(a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
kargs.ds_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
RunGemmWithPipelineSelection2LDS(
|
||||
a_ptr, b_ptr, c_ptr, kargs.ds_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
else // SingleSmemBuffer
|
||||
{
|
||||
@@ -331,7 +322,7 @@ struct GroupedGemmKernel
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
@@ -343,7 +334,7 @@ struct GroupedGemmKernel
|
||||
{b_ptr},
|
||||
kargs.ds_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
@@ -425,9 +416,7 @@ struct GroupedGemmKernel
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param ds_ptr input Ds pointer
|
||||
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
|
||||
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
|
||||
* @param smem_ptr The start memory pointer of the shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
* @param splitk_batch_offset Utility structure used to calculate k batch.
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
@@ -439,8 +428,7 @@ struct GroupedGemmKernel
|
||||
const BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
const std::array<const void*, NumDTensor_>& ds_ptr,
|
||||
void* __restrict__ smem_ptr_0,
|
||||
void* __restrict__ smem_ptr_1,
|
||||
void* __restrict__ smem_ptr,
|
||||
const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs,
|
||||
const typename Base::SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
@@ -460,8 +448,8 @@ struct GroupedGemmKernel
|
||||
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
|
||||
const auto& c_block_tile =
|
||||
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
if(kargs.k_batch == 1)
|
||||
@@ -469,7 +457,7 @@ struct GroupedGemmKernel
|
||||
auto c_block_window = Base::template MakeCBlockWindows<memory_operation_enum::set>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -477,7 +465,7 @@ struct GroupedGemmKernel
|
||||
Base::template MakeCBlockWindows<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -978,7 +978,7 @@ struct UniversalGemmKernel
|
||||
* @param bs_ptr input Bs pointer
|
||||
* @param ds_ptr input Ds pointer
|
||||
* @param e_ptr output E pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param smem_ptr The start memory pointer of the shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch.
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
@@ -990,7 +990,7 @@ struct UniversalGemmKernel
|
||||
const std::array<const BDataType*, NumBTensor>& bs_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* smem_ptr_0,
|
||||
void* smem_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
@@ -1008,7 +1008,7 @@ struct UniversalGemmKernel
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);
|
||||
as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr);
|
||||
|
||||
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
|
||||
// Run Epilogue Pipeline
|
||||
@@ -1016,77 +1016,63 @@ struct UniversalGemmKernel
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindows<memory_operation_enum::set>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindows<memory_operation_enum::atomic_add>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
* @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism.
|
||||
*
|
||||
* @param as_ptr input As pointer
|
||||
* @param bs_ptr input Bs pointer
|
||||
* @param ds_ptr input Ds pointer
|
||||
* @param e_ptr output E pointer
|
||||
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
|
||||
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
* @param splitk_batch_offset Utility structure used to calculate k batch.
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
*/
|
||||
CK_TILE_DEVICE static void RunGemm2LDS(const std::array<const ADataType*, NumATensor>& as_ptr,
|
||||
const std::array<const BDataType*, NumBTensor>& bs_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* __restrict__ smem_ptr_0,
|
||||
void* __restrict__ smem_ptr_1,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
CK_TILE_DEVICE static auto
|
||||
GetTileCoordinates(const KernelArgs& kargs) -> tuple<index_t, index_t>
|
||||
{
|
||||
// Create block windows using specialized methods
|
||||
const auto& as_block_window =
|
||||
MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& bs_block_window =
|
||||
MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
index_t iM, iN;
|
||||
|
||||
const index_t num_loop =
|
||||
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
// Regular launch: use 1D block indexing
|
||||
const auto blockId = amd_wave_read_first_lane(blockIdx.x);
|
||||
const auto [tile_m, tile_n] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
|
||||
iM = tile_m;
|
||||
iN = tile_n;
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(as_block_window,
|
||||
AElementWise{},
|
||||
bs_block_window,
|
||||
BElementWise{},
|
||||
num_loop,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1);
|
||||
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
if(kargs.k_batch == 1)
|
||||
return make_tuple(i_m, i_n);
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
CK_TILE_DEVICE static auto GetBlockId() -> index_t
|
||||
{
|
||||
// For 1D regular launch
|
||||
return amd_wave_read_first_lane(get_block_id());
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto GetGridSize() -> index_t
|
||||
{
|
||||
// For 1D regular launch
|
||||
return amd_wave_read_first_lane(get_grid_size());
|
||||
}
|
||||
|
||||
// Helper to get total number of tiles, handling both dim3 and index_t return types
|
||||
template <typename... Args>
|
||||
CK_TILE_HOST_DEVICE static auto GetNumTiles(Args&&... args) -> index_t
|
||||
{
|
||||
auto grid_size = TilePartitioner::GridSize(std::forward<Args>(args)...);
|
||||
|
||||
using GridSizeType = decltype(grid_size);
|
||||
|
||||
if constexpr(std::is_same_v<GridSizeType, dim3>)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindows<memory_operation_enum::set>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
// GridSize returns dim3: compute total tiles as x * y * z
|
||||
return amd_wave_read_first_lane(grid_size.x * grid_size.y * grid_size.z);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindows<memory_operation_enum::atomic_add>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
// GridSize returns scalar (index_t): use directly
|
||||
return amd_wave_read_first_lane(grid_size);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1123,36 +1109,12 @@ struct UniversalGemmKernel
|
||||
}
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
RunGemm2LDS(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1);
|
||||
RunGemm<scheduler_type>(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
constexpr auto scheduler_type =
|
||||
GemmPipeline::DoubleSmemBuffer || (GemmPipeline::NumWaveGroups == 1);
|
||||
RunGemm<scheduler_type>(
|
||||
as_ptr, bs_ptr, kargs.ds_ptr, e_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
|
||||
// Persistent kernel entry point
|
||||
@@ -1199,34 +1161,19 @@ struct UniversalGemmKernel
|
||||
}
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
// Run the GEMM
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
RunGemm2LDS(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
RunGemm(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
|
||||
RunGemm(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
|
||||
// Advance to the next work item
|
||||
block_id += grid_size;
|
||||
if(block_id >= num_work)
|
||||
|
||||
@@ -64,12 +64,17 @@ struct GemmPipelineAgBgCrImplBase
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
template <typename DstBlockTile, typename SrcTileWindow, typename DramTileWindowStep>
|
||||
template <typename SrcDataType = void,
|
||||
typename DstDataType = void,
|
||||
index_t UnaryOpSize = 8,
|
||||
typename DstBlockTile,
|
||||
typename SrcTileWindow,
|
||||
typename DramTileWindowStep>
|
||||
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
|
||||
SrcTileWindow& dram_tile_window,
|
||||
const DramTileWindowStep& dram_tile_window_step) const
|
||||
{
|
||||
load_tile(dst_block_tile, dram_tile_window);
|
||||
load_int4_tile<SrcDataType, DstDataType, UnaryOpSize>(dst_block_tile, dram_tile_window);
|
||||
move_tile_window(dram_tile_window, dram_tile_window_step);
|
||||
}
|
||||
|
||||
@@ -217,22 +222,17 @@ struct GemmPipelineAgBgCrImplBase
|
||||
return std::move(a_copy_dram_window);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename ALdsTensorView, typename ALdsLoadTileDistr>
|
||||
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const ALdsTensorView& a_lds_block_view,
|
||||
const ALdsLoadTileDistr&,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
template <typename ALdsTensorView, typename ALdsLoadTileDistr>
|
||||
CK_TILE_DEVICE constexpr auto MakeALdsWindows(const ALdsTensorView& a_lds_block_view,
|
||||
const ALdsLoadTileDistr&) const
|
||||
{
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset);
|
||||
|
||||
// A LDS tile window for store
|
||||
auto a_lds_shape = []() {
|
||||
if constexpr(is_a_load_tr)
|
||||
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<MPerBlock>{}, number<KPerBlock>{});
|
||||
}();
|
||||
|
||||
auto a_copy_lds_window = make_tile_window(a_lds_block_view, a_lds_shape, {0, 0});
|
||||
|
||||
auto a_lds_load_tile_distr = []() {
|
||||
@@ -244,32 +244,73 @@ struct GemmPipelineAgBgCrImplBase
|
||||
else
|
||||
return ALdsLoadTileDistr{};
|
||||
}();
|
||||
|
||||
auto a_lds_gemm_window =
|
||||
make_tile_window(a_lds_block_view, a_lds_shape, {0, 0}, a_lds_load_tile_distr);
|
||||
|
||||
return make_tuple(std::move(a_copy_lds_window), std::move(a_lds_gemm_window));
|
||||
}
|
||||
|
||||
template <
|
||||
typename ADramBlockWindowTmp,
|
||||
typename ALdsTensorView,
|
||||
typename ALdsLoadTileDistr,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ALdsTensorView>::value, bool>* = nullptr>
|
||||
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const ALdsTensorView& a_lds_block_view,
|
||||
const ALdsLoadTileDistr& a_lds_load_tile_distr,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset);
|
||||
|
||||
// Create LDS windows
|
||||
auto [a_copy_lds_window, a_lds_gemm_window] =
|
||||
MakeALdsWindows(a_lds_block_view, a_lds_load_tile_distr);
|
||||
|
||||
return make_tuple(std::move(a_copy_dram_window),
|
||||
std::move(a_copy_lds_window),
|
||||
std::move(a_lds_gemm_window));
|
||||
}
|
||||
|
||||
template <typename BDramBlockWindowTmp, typename BLdsTensorView, typename BLdsLoadTileDistr>
|
||||
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BLdsTensorView& b_lds_block_view,
|
||||
const BLdsLoadTileDistr&,
|
||||
// Unified GetAWindows that supports 1, 2, or 3 LDS buffers
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename ALdsTensorViewsTuple,
|
||||
typename ALdsLoadTileDistr,
|
||||
typename std::enable_if_t<is_detected<is_tuple, ALdsTensorViewsTuple>::value, bool>* =
|
||||
nullptr>
|
||||
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const ALdsTensorViewsTuple& a_lds_block_views_tuple,
|
||||
const ALdsLoadTileDistr& a_lds_load_tile_distr,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
// A DRAM tile window for load
|
||||
auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset);
|
||||
auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset);
|
||||
|
||||
// TODO: Do we really need those two tile windows???
|
||||
// They're exactly same...
|
||||
// B LDS tile window for store
|
||||
// Create LDS windows for each buffer
|
||||
constexpr index_t num_buffers = ALdsTensorViewsTuple::size();
|
||||
auto a_lds_windows = generate_tuple(
|
||||
[&](auto i) {
|
||||
return MakeALdsWindows(a_lds_block_views_tuple[i], a_lds_load_tile_distr);
|
||||
},
|
||||
number<num_buffers>{});
|
||||
|
||||
// Return: (dram_window, lds_windows_tuple)
|
||||
// lds_windows_tuple[i] = (copy_lds_window_i, lds_gemm_window_i)
|
||||
return make_tuple(std::move(a_copy_dram_window), std::move(a_lds_windows));
|
||||
}
|
||||
|
||||
template <typename BLdsTensorView, typename BLdsLoadTileDistr>
|
||||
CK_TILE_DEVICE constexpr auto MakeBLdsWindows(const BLdsTensorView& b_lds_block_view,
|
||||
const BLdsLoadTileDistr&) const
|
||||
{
|
||||
auto b_lds_shape = []() {
|
||||
if constexpr(is_b_load_tr)
|
||||
return make_tuple(number<KPerBlock>{}, number<NPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<NPerBlock>{}, number<KPerBlock>{});
|
||||
}();
|
||||
|
||||
auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0});
|
||||
|
||||
using BLdsDataType =
|
||||
@@ -286,13 +327,61 @@ struct GemmPipelineAgBgCrImplBase
|
||||
else
|
||||
return BLdsLoadTileDistr{};
|
||||
}();
|
||||
|
||||
auto b_lds_gemm_window =
|
||||
make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}, b_lds_load_tile_distr);
|
||||
|
||||
return make_tuple(std::move(b_copy_lds_window), std::move(b_lds_gemm_window));
|
||||
}
|
||||
|
||||
template <
|
||||
typename BDramBlockWindowTmp,
|
||||
typename BLdsTensorView,
|
||||
typename BLdsLoadTileDistr,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, BLdsTensorView>::value, bool>* = nullptr>
|
||||
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BLdsTensorView& b_lds_block_view,
|
||||
const BLdsLoadTileDistr& b_lds_load_tile_distr,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
// A DRAM tile window for load
|
||||
auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset);
|
||||
|
||||
// Create LDS windows
|
||||
auto [b_copy_lds_window, b_lds_gemm_window] =
|
||||
MakeBLdsWindows(b_lds_block_view, b_lds_load_tile_distr);
|
||||
|
||||
return make_tuple(std::move(b_copy_dram_window),
|
||||
std::move(b_copy_lds_window),
|
||||
std::move(b_lds_gemm_window));
|
||||
}
|
||||
|
||||
// Unified GetBWindows that supports 1, 2, or 3 LDS buffers
|
||||
template <typename BDramBlockWindowTmp,
|
||||
typename BLdsTensorViewsTuple,
|
||||
typename BLdsLoadTileDistr,
|
||||
typename std::enable_if_t<is_detected<is_tuple, BLdsTensorViewsTuple>::value, bool>* =
|
||||
nullptr>
|
||||
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BLdsTensorViewsTuple& b_lds_block_views_tuple,
|
||||
const BLdsLoadTileDistr& b_lds_load_tile_distr,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset);
|
||||
|
||||
// Create LDS windows for each buffer
|
||||
constexpr index_t num_buffers = BLdsTensorViewsTuple::size();
|
||||
auto b_lds_windows = generate_tuple(
|
||||
[&](auto i) {
|
||||
return MakeBLdsWindows(b_lds_block_views_tuple[i], b_lds_load_tile_distr);
|
||||
},
|
||||
number<num_buffers>{});
|
||||
|
||||
// Return: (dram_window, lds_windows_tuple)
|
||||
// lds_windows_tuple[i] = (copy_lds_window_i, lds_gemm_window_i)
|
||||
return make_tuple(std::move(b_copy_dram_window), std::move(b_lds_windows));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -158,6 +158,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
|
||||
static_assert(DoubleSmemBuffer == true, "pipeline requires double smem buffer");
|
||||
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
|
||||
@@ -172,7 +174,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
|
||||
return 2 * smem_size;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
|
||||
@@ -240,8 +243,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1) const
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
// TODO support multi-ABD
|
||||
static_assert(1 == std::tuple_size_v<AsDramBlockWindowTmp>);
|
||||
@@ -303,8 +305,10 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
number<BsLayout::size()>{});
|
||||
|
||||
// this pipeline has a pair of LDS buffers per logical tile
|
||||
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
|
||||
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
|
||||
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
|
||||
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem);
|
||||
auto&& [a_lds_block1, b_lds_block1] =
|
||||
Base::GetABLdsTensorViews(static_cast<char*>(p_smem) + smem_size);
|
||||
|
||||
// set up LDS tile shapes
|
||||
constexpr auto a_lds_shape = []() {
|
||||
@@ -534,21 +538,18 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem_0,
|
||||
void* p_smem_1) const
|
||||
void* p_smem) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
p_smem);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
@@ -559,8 +560,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1) const
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
@@ -572,8 +572,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
p_smem);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
|
||||
@@ -172,6 +172,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
|
||||
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
|
||||
|
||||
static_assert(DoubleSmemBuffer == true, "pipeline requires double smem buffer");
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
|
||||
{
|
||||
// clang-format off
|
||||
@@ -191,7 +193,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
|
||||
return 2 * smem_size;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
|
||||
@@ -281,8 +284,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1) const
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
@@ -324,8 +326,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
// global read 0
|
||||
|
||||
////////////// LDS desc, window & register /////////////////
|
||||
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
|
||||
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
|
||||
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
|
||||
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem);
|
||||
auto&& [a_lds_block1, b_lds_block1] =
|
||||
Base::GetABLdsTensorViews(static_cast<char*>(p_smem) + smem_size);
|
||||
|
||||
constexpr auto a_lds_shape = []() {
|
||||
if constexpr(is_a_load_tr_v())
|
||||
@@ -680,8 +684,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem_0,
|
||||
void* p_smem_1) const
|
||||
void* p_smem) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
@@ -693,8 +696,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
p_smem);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
@@ -708,8 +710,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1) const
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
@@ -721,8 +722,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
b_dram_block_window_tmp,
|
||||
[](auto& e, const BDataType& b) { e = b; },
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
p_smem);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
@@ -738,8 +738,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1) const
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
@@ -751,8 +750,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
b_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
@@ -769,16 +767,14 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem_0,
|
||||
void* p_smem_1) const
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
a_element_func,
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
@@ -789,14 +785,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1) const
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
@@ -809,16 +803,14 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1) const
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_number,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
p_smem);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -201,6 +202,12 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
constexpr index_t kNPerBlock = TileShape::kN;
|
||||
constexpr index_t kKPerBlock = TileShape::kK;
|
||||
constexpr index_t NIterPerWarp =
|
||||
kNPerBlock / TileShape::BlockWarps::at(I1) / TileShape::WarpTile::at(I1);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / TileShape::WarpTile::at(I2);
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
@@ -213,13 +220,13 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
#endif
|
||||
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t KRepeat = 1;
|
||||
constexpr index_t KRepeat = KIterPerWarp;
|
||||
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
|
||||
|
||||
constexpr index_t NBPerLoad = 1;
|
||||
constexpr index_t NThdPerWave = 1;
|
||||
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
|
||||
constexpr index_t NRepeat = 1;
|
||||
constexpr index_t NRepeat = NIterPerWarp;
|
||||
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
return make_static_tile_distribution(
|
||||
@@ -232,8 +239,8 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
|
||||
tuple<sequence<0, 1, 1>, sequence<1, 2, 2>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
sequence<1, 2, 1, 2>,
|
||||
sequence<0, 0, 3, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -307,7 +314,7 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
return BlockWeightPreshuffleASmemBSmemCRegV1<Problem, BlockWeightPreshufflePolicy>{};
|
||||
return BlockWeightPreshuffleASmemBRegCReg<Problem, BlockWeightPreshufflePolicy>{};
|
||||
}
|
||||
/**
|
||||
* @brief Get the vector store size for C tensor.
|
||||
@@ -325,7 +332,7 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetBlockWeightPreshuffle<Problem>())>;
|
||||
using WG_ = typename BlockGemm::WG;
|
||||
using WG_ = typename BlockGemm::WarpGemm;
|
||||
|
||||
constexpr bool TransposeC = Problem::TransposeC;
|
||||
using CLayout = typename Problem::CLayout;
|
||||
|
||||
@@ -32,19 +32,34 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool, TailNumber tail_number)
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
if(tail_number == TailNumber::Odd)
|
||||
if(has_hot_loop)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
if(tail_number == TailNumber::Odd)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
else // Even tail number
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
}
|
||||
else // Even tail number
|
||||
else
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Even>{});
|
||||
if(tail_number == TailNumber::Odd)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
else // Even tail number
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
}
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -52,7 +67,8 @@ template <typename Problem, typename PipelinePolicy = UniversalWeightPreshuffleP
|
||||
struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
: public BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>
|
||||
{
|
||||
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>;
|
||||
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, PipelinePolicy>;
|
||||
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
@@ -75,11 +91,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
using BlockWeightPreshuffle =
|
||||
remove_cvref_t<decltype(PipelinePolicy::template GetBlockWeightPreshuffle<Problem>())>;
|
||||
|
||||
static constexpr auto config =
|
||||
BlockWeightPreshuffle::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2
|
||||
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
|
||||
|
||||
@@ -95,6 +106,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t kflatKPerBlock = BlockGemmShape::flatKPerBlock;
|
||||
|
||||
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
|
||||
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
|
||||
|
||||
@@ -131,12 +144,16 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
using BlockWarps = remove_cvref_t<typename BlockGemmShape::BlockWarps>;
|
||||
using WarpTile = remove_cvref_t<typename BlockGemmShape::WarpTile>;
|
||||
|
||||
static constexpr index_t MWarp = config.template at<1>();
|
||||
static constexpr index_t NWarp = config.template at<2>();
|
||||
static constexpr index_t MWarp = BlockWarps::at(I0);
|
||||
static constexpr index_t NWarp = BlockWarps::at(I1);
|
||||
|
||||
static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
|
||||
static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
|
||||
static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
|
||||
static constexpr index_t WarpTileM = WarpTile::at(I0);
|
||||
static constexpr index_t WarpTileN = WarpTile::at(I1);
|
||||
static constexpr index_t WarpTileK = WarpTile::at(I2);
|
||||
|
||||
static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpTileM);
|
||||
static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpTileN);
|
||||
static constexpr index_t KIterPerWarp = kKPerBlock / WarpTileK;
|
||||
|
||||
static constexpr index_t KFlatPerBlockPerIter = flatKPerWarp;
|
||||
static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
|
||||
@@ -154,20 +171,20 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
#else
|
||||
static constexpr index_t mfma_per_wg = 1;
|
||||
#endif
|
||||
static constexpr index_t dsread_per_wg =
|
||||
max(index_t(WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize), 1);
|
||||
static constexpr index_t dsread_per_wg = max(
|
||||
index_t(WarpTileM * WarpTileK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize), 1);
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
static_assert((WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize) %
|
||||
static_assert((WarpTileM * WarpTileK * sizeof(ADataType) * MIterPerWarp / WaveSize) %
|
||||
Problem::VectorLoadSize ==
|
||||
0);
|
||||
#endif
|
||||
static constexpr index_t dsread_num_perK =
|
||||
WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize / Problem::VectorLoadSize;
|
||||
static constexpr index_t dsread_num_perK = WarpTileM * WarpTileK * sizeof(ADataType) *
|
||||
MIterPerWarp / WaveSize / Problem::VectorLoadSize;
|
||||
static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp);
|
||||
static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp;
|
||||
static constexpr index_t Aload_num_perK = dswrite_num_perK;
|
||||
static constexpr index_t Aload_rep = dswrite_rep;
|
||||
static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / K1 / WaveSize;
|
||||
static constexpr index_t Bload_num_perK = kNPerBlock * WarpTileK / NWarp / K1 / WaveSize;
|
||||
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
|
||||
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
|
||||
|
||||
@@ -187,7 +204,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AGmemBGmemCRegV2",
|
||||
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
|
||||
concat('x', WG::kM, WG::kN, WG::kK),
|
||||
concat('x', WarpTileM, WarpTileN, WarpTileK),
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB()),
|
||||
concat('x', kPadM, kPadN, kPadK));
|
||||
|
||||
@@ -195,14 +212,16 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
}
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
using Base::UsePersistentKernel;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return PipelinePolicy::template GetSmemSize<Problem>();
|
||||
constexpr index_t smem_size = PipelinePolicy::template GetSmemSize<Problem>();
|
||||
return DoubleSmemBuffer ? 2 * smem_size : smem_size;
|
||||
}
|
||||
|
||||
// dsread_perM: how many LDS reads want to issue in this M-iter
|
||||
@@ -515,515 +534,184 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
template <TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
|
||||
bool>* = nullptr,
|
||||
index_t UnaryOpSize_ = 8>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
struct PipelineImpl : public PipelineImplBase
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
|
||||
"wrong!");
|
||||
static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem_ping);
|
||||
ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
|
||||
|
||||
constexpr auto a_lds_block_desc =
|
||||
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
|
||||
|
||||
auto a_lds_block_ping =
|
||||
make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
|
||||
auto a_lds_block_pong =
|
||||
make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
|
||||
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
auto a_copy_lds_window_ping =
|
||||
make_tile_window(a_lds_block_ping,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
auto a_copy_lds_window_pong =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// ping-pong window for A LDS
|
||||
auto a_warp_window_ping_tmp =
|
||||
make_tile_window(a_lds_block_ping,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
|
||||
|
||||
auto a_warp_window_pong_tmp =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows_ping;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_pong_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows_pong;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows_ping(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows_pong(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// Block GEMM
|
||||
auto block_weight_preshuffle = BlockWeightPreshuffle();
|
||||
// Acc register tile
|
||||
auto c_block_tile = block_weight_preshuffle.MakeCBlockTile();
|
||||
|
||||
// B flat DRAM window for load
|
||||
auto b_flat_distribution =
|
||||
PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
|
||||
auto b_flat_dram_window = // tile_window_with_static_distribution
|
||||
make_tile_window(
|
||||
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
|
||||
// pingpong buffer for B
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
using BTileType = decltype(make_static_distributed_tensor<BTypeToUse>(b_flat_distribution));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_flat_dram_windows;
|
||||
|
||||
statically_indexed_array<statically_indexed_array<BTileType, KIterPerWarp>, NIterPerWarp>
|
||||
b_warp_tensor_ping;
|
||||
|
||||
statically_indexed_array<statically_indexed_array<BTileType, KIterPerWarp>, NIterPerWarp>
|
||||
b_warp_tensor_pong;
|
||||
|
||||
// Prefetch A0
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// prefetch B
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
|
||||
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
// Prefill A0
|
||||
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Prefetch A1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// preload A00,A10 from lds
|
||||
statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))),
|
||||
m_preload>
|
||||
a_warp_tensor;
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor(loadIter) =
|
||||
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// MAIN LOOP
|
||||
index_t iCounter = (num_loop - 1) / 2;
|
||||
while(iCounter > 0)
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
|
||||
bool>* = nullptr,
|
||||
index_t UnaryOpSize_ = 8>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
[[maybe_unused]] const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
// prefetch B(2i+1)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
|
||||
"wrong!");
|
||||
static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
|
||||
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
// A tile in LDS
|
||||
constexpr index_t smem_size = PipelinePolicy::template GetSmemSize<Problem>();
|
||||
|
||||
// Prefill A(2i+1)
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
|
||||
constexpr auto a_lds_block_desc =
|
||||
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
|
||||
|
||||
// Prefetch A(2i+2)
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
auto a_lds_blocks = generate_tuple(
|
||||
[&](auto i) {
|
||||
ADataType* p_a_lds = static_cast<ADataType*>(
|
||||
static_cast<void*>(static_cast<char*>(p_smem) + smem_size * i.value));
|
||||
return make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
},
|
||||
number<2>{});
|
||||
|
||||
// GEMM 2i
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
constexpr auto a_lds_load_tile_distr = make_static_tile_distribution(
|
||||
BlockWeightPreshuffle::MakeABlockDistributionEncode());
|
||||
auto&& windows_result =
|
||||
Base::GetAWindows(a_dram_block_window_tmp, a_lds_blocks, a_lds_load_tile_distr);
|
||||
auto&& a_copy_dram_window = windows_result.template get<0>();
|
||||
auto&& a_lds_windows = windows_result.template get<1>();
|
||||
auto a_copy_lds_windows = generate_tuple(
|
||||
[&](auto i) -> decltype(auto) { return a_lds_windows[i].template at<0>(); },
|
||||
number<2>{});
|
||||
// Block GEMM
|
||||
auto block_weight_preshuffle = BlockWeightPreshuffle();
|
||||
// Acc register tile
|
||||
auto c_block_tile = block_weight_preshuffle.MakeCBlockTile();
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
auto a_load_windows = generate_tuple(
|
||||
[&](auto i) -> decltype(auto) {
|
||||
return block_weight_preshuffle.MakeALoadWindows(a_copy_lds_windows[i]);
|
||||
},
|
||||
number<2>{});
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
// B flat DRAM window for load
|
||||
auto b_flat_distribution =
|
||||
PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
|
||||
auto b_flat_dram_window = // tile_window_with_static_distribution
|
||||
make_tile_window(b_flat_dram_block_window_tmp
|
||||
.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
make_tuple(number<flatNPerWarp * NIterPerWarp>{},
|
||||
number<flatKPerWarp * KIterPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BFlatBlockWindowTmp::BottomTensorIndex;
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kflatKPerBlock);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
using BBlockTile =
|
||||
decltype(make_static_distributed_tensor<BTypeToUse>(b_flat_distribution));
|
||||
|
||||
ABlockTile a_global_tile;
|
||||
BBlockTile b_global_tile[2];
|
||||
|
||||
// // Prefetch A0
|
||||
Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
|
||||
b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// Prefill A0
|
||||
Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile);
|
||||
|
||||
// Prefetch A1
|
||||
Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// preload A00,A10 from lds
|
||||
block_weight_preshuffle.LocalPrefetch(a_load_windows[I0]);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// MAIN LOOP
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
index_t i_global_read = amd_wave_read_first_lane(2);
|
||||
do
|
||||
{
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
|
||||
b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step);
|
||||
Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile);
|
||||
Base::GlobalPrefetch(
|
||||
a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
block_weight_preshuffle(c_block_tile,
|
||||
a_load_windows[I0],
|
||||
b_global_tile[0],
|
||||
b_flat_distribution);
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
block_weight_preshuffle.LocalPrefetch(a_load_windows[I1]);
|
||||
HotLoopScheduler();
|
||||
}
|
||||
{
|
||||
block_sync_lds();
|
||||
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
|
||||
b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step);
|
||||
Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile);
|
||||
Base::GlobalPrefetch(
|
||||
a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
block_weight_preshuffle(c_block_tile,
|
||||
a_load_windows[I1],
|
||||
b_global_tile[1],
|
||||
b_flat_distribution);
|
||||
|
||||
block_weight_preshuffle.LocalPrefetch(a_load_windows[I0]);
|
||||
HotLoopScheduler();
|
||||
}
|
||||
});
|
||||
});
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
i_global_read += 2;
|
||||
} while(i_global_read < num_loop);
|
||||
}
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor(loadIter) =
|
||||
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
HotLoopScheduler();
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
{
|
||||
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
|
||||
b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step);
|
||||
Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile);
|
||||
block_weight_preshuffle(
|
||||
c_block_tile, a_load_windows[I0], b_global_tile[0], b_flat_distribution);
|
||||
block_sync_lds();
|
||||
block_weight_preshuffle.LocalPrefetch(a_load_windows[I1]);
|
||||
Last2ndHotLoopScheduler();
|
||||
}
|
||||
{
|
||||
block_weight_preshuffle(
|
||||
c_block_tile, a_load_windows[I1], b_global_tile[1], b_flat_distribution);
|
||||
LastHotLoopScheduler();
|
||||
}
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
block_weight_preshuffle(
|
||||
c_block_tile, a_load_windows[I0], b_global_tile[0], b_flat_distribution);
|
||||
LastHotLoopScheduler();
|
||||
}
|
||||
|
||||
// Next K
|
||||
|
||||
// prefetch B(2i+2)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
|
||||
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
// Prefill A(2i+2)
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
|
||||
|
||||
// Prefetch A(2i+3)
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// GEMM 2i+1
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
});
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor(loadIter) =
|
||||
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
HotLoopScheduler();
|
||||
|
||||
iCounter--;
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
// prefetch B(loopK)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
|
||||
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
// Prefill A(loopK)
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
|
||||
|
||||
// GEMM loopK-1
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
});
|
||||
// TailHotLoopScheduler();
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor(loadIter) =
|
||||
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
|
||||
Last2ndHotLoopScheduler();
|
||||
|
||||
// GEMM loopK
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
});
|
||||
LastHotLoopScheduler();
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
// GEMM loopK
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
});
|
||||
LastHotLoopScheduler();
|
||||
}
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
|
||||
// called from universal gemm kernel
|
||||
template <typename ADramBlockWindowTmp,
|
||||
@@ -1038,23 +726,20 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
[[maybe_unused]] const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
void* p_smem) const
|
||||
{
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
const auto has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
|
||||
(void)bool_val; // Suppress unused parameter warning
|
||||
constexpr auto tail_num = tail_num_.value;
|
||||
constexpr auto PassThrough = [](const ADataType& a) { return a; };
|
||||
return operator()<tail_num>(a_dram_block_window_tmp[number<0>{}],
|
||||
PassThrough,
|
||||
b_flat_dram_block_window_tmp[number<0>{}],
|
||||
num_loop,
|
||||
p_smem_ping,
|
||||
p_smem_pong);
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp[number<0>{}],
|
||||
a_element_func,
|
||||
b_flat_dram_block_window_tmp[number<0>{}],
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, true, tail_number);
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
// called from general gemm kernel
|
||||
@@ -1066,23 +751,21 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
void* p_smem) const
|
||||
{
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
const auto has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
|
||||
(void)bool_val; // Suppress unused parameter warning
|
||||
constexpr auto tail_num = tail_num_.value;
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
constexpr auto PassThrough = [](const ADataType& a) { return a; };
|
||||
return operator()<tail_num>(a_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
b_flat_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem_ping,
|
||||
p_smem_pong);
|
||||
return PipelineImpl{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
b_flat_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, true, tail_number);
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
// called from grouped gemm kernel
|
||||
@@ -1095,21 +778,19 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
TailNumber tail_number,
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1) const
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
|
||||
(void)bool_val; // Suppress unused parameter warning
|
||||
constexpr auto tail_num = tail_num_.value;
|
||||
const auto has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
constexpr auto PassThrough = [](const auto& x) { return x; };
|
||||
return operator()<tail_num>(a_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
b_flat_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
return PipelineImpl{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
b_flat_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, true, tail_number);
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1723,7 +1723,7 @@ struct QuantGemmKernel
|
||||
* @param aq_ptr input AQ pointer
|
||||
* @param bq_ptr input BQ pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param smem_ptr The start memory pointer of the shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch.
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
@@ -1735,7 +1735,7 @@ struct QuantGemmKernel
|
||||
const AQDataType* aq_ptr,
|
||||
const BQDataType* bq_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
void* smem_ptr,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
@@ -1762,7 +1762,7 @@ struct QuantGemmKernel
|
||||
m = kargs.M;
|
||||
}
|
||||
return GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr_0, m);
|
||||
a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr, m);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
@@ -1772,7 +1772,7 @@ struct QuantGemmKernel
|
||||
n = kargs.N;
|
||||
}
|
||||
return GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0, n);
|
||||
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr, n);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
@@ -1788,7 +1788,7 @@ struct QuantGemmKernel
|
||||
aq_block_window,
|
||||
bq_block_window,
|
||||
num_loop,
|
||||
smem_ptr_0,
|
||||
smem_ptr,
|
||||
m,
|
||||
n);
|
||||
}
|
||||
@@ -1796,7 +1796,7 @@ struct QuantGemmKernel
|
||||
kQuantType == QuantType::TensorQuant)
|
||||
{
|
||||
return GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0);
|
||||
a_block_window, b_block_window, num_loop, smem_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -1812,14 +1812,14 @@ struct QuantGemmKernel
|
||||
kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
c_block_window,
|
||||
smem_ptr_0,
|
||||
smem_ptr,
|
||||
aq_block_window,
|
||||
bq_block_window);
|
||||
}
|
||||
@@ -1828,7 +1828,7 @@ struct QuantGemmKernel
|
||||
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
|
||||
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -1840,14 +1840,14 @@ struct QuantGemmKernel
|
||||
kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
c_block_window,
|
||||
smem_ptr_0,
|
||||
smem_ptr,
|
||||
aq_block_window,
|
||||
bq_block_window);
|
||||
}
|
||||
@@ -1856,89 +1856,7 @@ struct QuantGemmKernel
|
||||
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
|
||||
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
* @note RunGemm2LDS in with two shared memory buffers using the ping pong buffer mechanism.
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param aq_ptr input AQ pointer
|
||||
* @param bq_ptr input BQ pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
|
||||
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
* @param splitk_batch_offset Utility structure used to calculate k batch.
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
*/
|
||||
CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
[[maybe_unused]] const AQDataType* aq_ptr,
|
||||
const BQDataType* bq_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* __restrict__ smem_ptr_0,
|
||||
void* __restrict__ smem_ptr_1,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window =
|
||||
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& b_block_window =
|
||||
MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop =
|
||||
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& c_block_tile = [&]() {
|
||||
if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
index_t n = 0;
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
n = kargs.N;
|
||||
}
|
||||
return GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
bq_block_window,
|
||||
num_loop,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
n);
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
}();
|
||||
|
||||
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
|
||||
|
||||
// Run Epilogue Pipeline with k_batch dispatch
|
||||
if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
if(k_batch == 1)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1961,37 +1879,10 @@ struct QuantGemmKernel
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
|
||||
RunGemm2LDS(a_ptr,
|
||||
b_ptr,
|
||||
aq_ptr,
|
||||
bq_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
RunGemm(a_ptr,
|
||||
b_ptr,
|
||||
aq_ptr,
|
||||
bq_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
RunGemm(
|
||||
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -318,21 +318,18 @@ struct QuantGroupedGemmKernel
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// Only for BQuantGrouped DoubleSmemBuffer is supported
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true &&
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
RunGemmWithPipelineSelection2LDS(a_ptr,
|
||||
b_ptr,
|
||||
aq_ptr,
|
||||
bq_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
@@ -348,7 +345,7 @@ struct QuantGroupedGemmKernel
|
||||
aq_ptr,
|
||||
bq_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
@@ -361,7 +358,7 @@ struct QuantGroupedGemmKernel
|
||||
aq_ptr,
|
||||
bq_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
@@ -377,8 +374,7 @@ struct QuantGroupedGemmKernel
|
||||
[[maybe_unused]] const AQDataType* aq_ptr,
|
||||
const BQDataType* bq_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
void* smem_ptr_1,
|
||||
void* smem_ptr,
|
||||
const QuantGroupedGemmKernelArgs& kargs,
|
||||
const typename Base::SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
@@ -399,27 +395,22 @@ struct QuantGroupedGemmKernel
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
bq_block_window,
|
||||
num_loop,
|
||||
tail_num,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1);
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, bq_block_window, num_loop, tail_num, smem_ptr);
|
||||
|
||||
// Run Epilogue Pipeline with split_k dispatch
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto c_block_window = Base::template MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window =
|
||||
Base::template MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -435,7 +426,7 @@ struct QuantGroupedGemmKernel
|
||||
* @param aq_ptr input AQ pointer
|
||||
* @param bq_ptr input BQ pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param smem_ptr The start memory pointer of the shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k
|
||||
* batch.
|
||||
@@ -449,7 +440,7 @@ struct QuantGroupedGemmKernel
|
||||
const AQDataType* aq_ptr,
|
||||
const BQDataType* bq_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
void* smem_ptr,
|
||||
const QuantGroupedGemmKernelArgs& kargs,
|
||||
const typename Base::SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
@@ -481,7 +472,7 @@ struct QuantGroupedGemmKernel
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0);
|
||||
smem_ptr);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
@@ -491,13 +482,13 @@ struct QuantGroupedGemmKernel
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0);
|
||||
smem_ptr);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant ||
|
||||
kQuantType == QuantType::TensorQuant)
|
||||
{
|
||||
return GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
|
||||
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -510,14 +501,14 @@ struct QuantGroupedGemmKernel
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
c_block_window,
|
||||
smem_ptr_0,
|
||||
smem_ptr,
|
||||
aq_block_window,
|
||||
bq_block_window);
|
||||
}
|
||||
@@ -526,7 +517,7 @@ struct QuantGroupedGemmKernel
|
||||
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
|
||||
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -538,14 +529,14 @@ struct QuantGroupedGemmKernel
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
c_block_window,
|
||||
smem_ptr_0,
|
||||
smem_ptr,
|
||||
aq_block_window,
|
||||
bq_block_window);
|
||||
}
|
||||
@@ -554,7 +545,7 @@ struct QuantGroupedGemmKernel
|
||||
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
|
||||
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,6 +29,48 @@ struct GemmWPQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipelin
|
||||
return GemmBQuantPipelineAgBgCrDefaultPolicy::MakeBQDramTileDistribution<Problem>();
|
||||
}
|
||||
|
||||
// as UniversalWeightPreshufflePipelineAgBgCrPolicy's MakeBFlatDramTileDistribution is changed;
|
||||
// move original UniversalWeightPreshufflePipelineAgBgCrPolicy's implementation to here
|
||||
// temporarily
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
|
||||
#if defined(__gfx11__)
|
||||
constexpr index_t KRepeatInWave = 2;
|
||||
#else
|
||||
constexpr index_t KRepeatInWave = 1;
|
||||
#endif
|
||||
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t KRepeat = 1;
|
||||
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
|
||||
|
||||
constexpr index_t NBPerLoad = 1;
|
||||
constexpr index_t NThdPerWave = 1;
|
||||
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<WaveRepeat, KRepeatInWave>, // ?
|
||||
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
|
||||
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
|
||||
tuple<sequence<0, 1, 1>, sequence<1, 2, 2>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffleBQuant()
|
||||
{
|
||||
|
||||
@@ -184,8 +184,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
|
||||
index_t n,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
@@ -210,8 +209,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem_ping);
|
||||
ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
|
||||
constexpr index_t smem_size = PipelinePolicy::template GetSmemSize<Problem>();
|
||||
ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem);
|
||||
ADataType* p_a_lds_pong =
|
||||
reinterpret_cast<ADataType*>(static_cast<char*>(p_smem) + smem_size);
|
||||
|
||||
constexpr auto a_lds_block_desc =
|
||||
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
|
||||
@@ -561,9 +562,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong,
|
||||
index_t n = 0) const // Default value for non-preshuffle case
|
||||
void* p_smem,
|
||||
index_t n = 0) const
|
||||
{
|
||||
return operator()<TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
@@ -572,8 +572,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
bq_dram_block_window_tmp,
|
||||
n,
|
||||
num_loop,
|
||||
p_smem_ping,
|
||||
p_smem_pong);
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
@@ -584,8 +583,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
TailNumber tail_number,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong,
|
||||
void* p_smem,
|
||||
index_t n = 0) const
|
||||
{
|
||||
const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
|
||||
@@ -598,8 +596,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
bq_dram_block_window_tmp,
|
||||
n, // dummy value, won't be used
|
||||
num_loop,
|
||||
p_smem_ping,
|
||||
p_smem_pong);
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, true, tail_number);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user