mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
support even odd k in tail
This commit is contained in:
@@ -67,6 +67,9 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 16;
|
||||
#endif
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
|
||||
using CodegenFlatmmShape =
|
||||
ck_tile::TileFlatmmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
@@ -74,15 +77,35 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenFlatmmShape>;
|
||||
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenFlatmmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
using CodegenGemmTraits =
|
||||
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits>;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
@@ -127,21 +150,73 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
if(args.k_batch == 1)
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
|
||||
if(has_hot_loop)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "For compute pipeline tail number should always be Full, but have \"" << tail_num
|
||||
<< "\" which is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Num K loop must be larger than number of prefetech stages."
|
||||
<< "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_flatmm_example.inc"
|
||||
|
||||
@@ -3,13 +3,29 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
// #define BLOCKWISE_LOADSTORE
|
||||
// #define FINEGRADE_LOADSTORE
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem>
|
||||
struct BaseFlatmmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
|
||||
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy>
|
||||
struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
@@ -24,6 +40,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
using BlockFlatmm =
|
||||
remove_cvref_t<decltype(PipelinePolicy::template GetBlockFlatmm<Problem>())>;
|
||||
|
||||
static constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
@@ -54,6 +74,28 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
using BlockWarps = remove_cvref_t<typename BlockGemmShape::BlockWarps>;
|
||||
using WarpTile = remove_cvref_t<typename BlockGemmShape::WarpTile>;
|
||||
|
||||
static constexpr index_t MWarp = config.template at<1>();
|
||||
static constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
|
||||
static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
|
||||
static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
|
||||
|
||||
static constexpr index_t KFlatPerBlockPerIter = flatKPerWarp;
|
||||
static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
|
||||
|
||||
static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
|
||||
static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
|
||||
|
||||
static constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1;
|
||||
static constexpr index_t ACopyLoadNumPerK = ACopyLoadNum / KIterPerWarp;
|
||||
static constexpr index_t AcopyPerLoadM = kMPerBlock / ACopyLoadNum;
|
||||
static constexpr index_t BloadGap = MIterPerWarp / 2;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
@@ -76,88 +118,167 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
#if 0
|
||||
static_for<0, 7, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 2, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 2, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
|
||||
// Keypoint of pipeline optimize is workload balance in time
|
||||
// instruction schedule example(128X256X256, 1X4, 16X16X128):
|
||||
// Iter MNK MFMA ds_read ds_write A_load b_load
|
||||
// -1 M6N3: 60 2 - - -
|
||||
// -1 M7N0: 61 - - - -
|
||||
// -1 M7N1: 62 - - - -
|
||||
// -1 M7N2: 63 - - - -
|
||||
// -1 M7N3: 64 4 - - -
|
||||
// 0 M0N0K0: 1 - - - -
|
||||
// 0 M0N1: 2 - - - 2
|
||||
// 0 M0N2: 3 - - - -
|
||||
// 0 M0N3: 4 6 - - -
|
||||
// 0 M1N0: 5 - - - -
|
||||
// 0 M1N1: 6 - - - 4
|
||||
// 0 M1N2: 7 - - - -
|
||||
// 0 M1N3: 8 8 - - -
|
||||
// 0 M2N0: 9 - - - -
|
||||
// 0 M2N1: 10 - - - 6
|
||||
// 0 M2N2: 11 - - - -
|
||||
// 0 M2N3: 12 10 - - -
|
||||
// 0 M3N0: 13 - 1 - -
|
||||
// 0 M3N1: 14 - - - 8
|
||||
// 0 M3N2: 15 - - - -
|
||||
// 0 M3N3: 16 12 - - -
|
||||
// 0 M4N0: 17 - 2 - -
|
||||
// 0 M4N1: 18 - - - -
|
||||
// 0 M4N2: 19 - - 1 -
|
||||
// 0 M4N3: 20 14 - - -
|
||||
// 0 M5N0: 21 - 3 - -
|
||||
// 0 M5N1: 22 - - - -
|
||||
// 0 M5N2: 23 - - 2 -
|
||||
// 0 M5N3: 24 16 - - -
|
||||
// 0 M6N0: 25 - 4 - -
|
||||
// 0 M6N1: 26 - - - -
|
||||
// 0 M6N2: 27 - - 3 -
|
||||
// 0 M6N3: 28 17 - - -
|
||||
// 0 M7N0: 29 - - - -
|
||||
// 0 M7N1: 30 - - - -
|
||||
// 0 M7N2: 31 - - 4 -
|
||||
// 0 M7N3: 32 18 - - -
|
||||
// 0 M0N0K1: 33 - - - -
|
||||
// 0 M0N1: 34 - - - 10
|
||||
// 0 M0N2: 35 - - - -
|
||||
// 0 M0N3: 36 20 - - -
|
||||
// 0 M1N0: 37 - - - -
|
||||
// 0 M1N1: 38 - - - 12
|
||||
// 0 M1N2: 39 - - - -
|
||||
// 0 M1N3: 40 22 - - -
|
||||
// 0 M2N0: 41 - - - -
|
||||
// 0 M2N1: 42 - - - 14
|
||||
// 0 M2N2: 43 - - - -
|
||||
// 0 M2N3: 44 24 - - -
|
||||
// 0 M3N0: 45 - 5 - -
|
||||
// 0 M3N1: 46 - - - 16
|
||||
// 0 M3N2: 47 - - - -
|
||||
// 0 M3N3: 48 26 - - -
|
||||
// 0 M4N0: 49 - 6 - -
|
||||
// 0 M4N1: 50 - - - -
|
||||
// 0 M4N2: 51 - - 5 -
|
||||
// 0 M4N3: 52 28 - - -
|
||||
// 0 M5N0: 53 - 7 - -
|
||||
// 0 M5N1: 54 - - - -
|
||||
// 0 M5N2: 55 - - 6 -
|
||||
// 0 M5N3: 56 30 - - -
|
||||
// 0 M6N0: 57 - 8 - -
|
||||
// 0 M6N1: 58 - - - -
|
||||
// 0 M6N2: 59 - - 7 -
|
||||
// 0 M6N3: 60 2 - - -
|
||||
// 0 M7N0: 61 - - - -
|
||||
// 0 M7N1: 62 - - - -
|
||||
// 0 M7N2: 63 - - 8 -
|
||||
// 0 M7N3: 64 4 - - -
|
||||
|
||||
static_for<0, 7, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
#endif
|
||||
#if 1
|
||||
static_for<0, 2, 1>{}([&](auto j) {
|
||||
ignore = j;
|
||||
static_for<0, 3, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
|
||||
static_for<0, 3, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, 2, 1>{}([&](auto j) {
|
||||
ignore = j;
|
||||
static_for<0, 3, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
|
||||
static_for<0, 3, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
#endif
|
||||
#if 0
|
||||
static_for<0, 2, 1>{}([&](auto j) {
|
||||
ignore = j;
|
||||
static_for<0, 3, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
|
||||
static_for<0, 3, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TailHotLoopScheduler()
|
||||
{
|
||||
#if 0
|
||||
static_for<0, 2, 1>{}([&](auto j) {
|
||||
ignore = j;
|
||||
static_for<0, 3, 1>{}([&](auto i) {
|
||||
@@ -199,6 +320,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
|
||||
@@ -218,29 +340,6 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
|
||||
|
||||
constexpr index_t KFlatPerBlockPerIter = flatKPerWarp;
|
||||
constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
|
||||
|
||||
constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1;
|
||||
constexpr index_t ACopyLoadNumPerK = ACopyLoadNum / KIterPerWarp;
|
||||
constexpr index_t AcopyPerLoadM = kMPerBlock / ACopyLoadNum;
|
||||
constexpr index_t BloadGap = MIterPerWarp / 2;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
@@ -263,7 +362,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
auto a_lds_block_pong = make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
|
||||
|
||||
// A DRAM tile window for load
|
||||
#ifdef BLOCKWISE_LOADSTORE
|
||||
#ifndef FINEGRADE_LOADSTORE
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
@@ -399,7 +498,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
|
||||
// Prefetch A0
|
||||
#ifdef BLOCKWISE_LOADSTORE
|
||||
#ifndef FINEGRADE_LOADSTORE
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
@@ -438,7 +537,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
// {
|
||||
// store_tile(a_copy_lds_window_ping, tile_elementwise_in(a_element_func, a_block_tile));
|
||||
// }
|
||||
#ifdef BLOCKWISE_LOADSTORE
|
||||
#ifndef FINEGRADE_LOADSTORE
|
||||
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
|
||||
#else
|
||||
@@ -449,7 +548,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Prefetch A1
|
||||
#ifdef BLOCKWISE_LOADSTORE
|
||||
#ifndef FINEGRADE_LOADSTORE
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
@@ -485,9 +584,11 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
|
||||
index_t iCounter = num_loop / 2 - 1;
|
||||
while(iCounter > 0)
|
||||
// if constexpr(HasMainLoop)
|
||||
// {
|
||||
while(iCounter > 0)
|
||||
{
|
||||
#ifdef BLOCKWISE_LOADSTORE
|
||||
#ifndef FINEGRADE_LOADSTORE
|
||||
// prefetch B(2i+1)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
@@ -531,7 +632,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
|
||||
#ifndef BLOCKWISE_LOADSTORE
|
||||
#ifdef FINEGRADE_LOADSTORE
|
||||
// prefetch B(2i+1)
|
||||
constexpr auto curMNIter = mIter * NIterPerWarp + nIter;
|
||||
if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1))
|
||||
@@ -586,7 +687,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
//Next K
|
||||
|
||||
#ifdef BLOCKWISE_LOADSTORE
|
||||
#ifndef FINEGRADE_LOADSTORE
|
||||
// prefetch B(2i+2)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
@@ -629,7 +730,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
|
||||
#ifndef BLOCKWISE_LOADSTORE
|
||||
#ifdef FINEGRADE_LOADSTORE
|
||||
// prefetch B(2i+2)
|
||||
constexpr auto curMNIter = mIter * NIterPerWarp + nIter;
|
||||
if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1))
|
||||
@@ -685,129 +786,191 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
iCounter--;
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
#ifdef BLOCKWISE_LOADSTORE
|
||||
// prefetch B(loopK)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
// Prefill A(loopK)
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
|
||||
#endif
|
||||
|
||||
// GEMM loopK-1
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = mIter % 2;
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
#ifndef FINEGRADE_LOADSTORE
|
||||
// prefetch B(loopK)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor_ping(number<AwarpIter>{}), b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
|
||||
#ifndef BLOCKWISE_LOADSTORE
|
||||
// prefetch B(loopK)
|
||||
constexpr auto curMNIter = mIter * NIterPerWarp + nIter;
|
||||
if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1))
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
// Prefill A(loopK)
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
|
||||
#endif
|
||||
|
||||
// GEMM loopK-1
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = mIter % 2;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor_ping(number<AwarpIter>{}), b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
|
||||
#ifdef FINEGRADE_LOADSTORE
|
||||
// prefetch B(loopK)
|
||||
constexpr auto curMNIter = mIter * NIterPerWarp + nIter;
|
||||
if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1))
|
||||
{
|
||||
constexpr auto BnIter = curMNIter / BloadGap;
|
||||
constexpr auto BkIter = kIter;
|
||||
b_flat_dram_windows(number<BnIter>{})(BkIter) = b_flat_dram_window;
|
||||
move_tile_window(b_flat_dram_windows(number<BnIter>{})(BkIter),
|
||||
{BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter});
|
||||
b_warp_tensor_pong(number<BnIter>{})(BkIter) = load_tile(b_flat_dram_windows(number<BnIter>{})(BkIter));
|
||||
}
|
||||
// Prefill A(loopK)
|
||||
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0))
|
||||
{
|
||||
constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum;
|
||||
store_tile(a_copy_lds_window_pong(number<AIter>{}), tile_elementwise_in(a_element_func, a_block_tile(number<AIter>{})));
|
||||
}
|
||||
#endif
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter!=KIterPerWarp-1)||(mIter<(MIterPerWarp-2)))
|
||||
{
|
||||
constexpr auto BnIter = curMNIter / BloadGap;
|
||||
constexpr auto BkIter = kIter;
|
||||
b_flat_dram_windows(number<BnIter>{})(BkIter) = b_flat_dram_window;
|
||||
move_tile_window(b_flat_dram_windows(number<BnIter>{})(BkIter),
|
||||
{BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter});
|
||||
b_warp_tensor_pong(number<BnIter>{})(BkIter) = load_tile(b_flat_dram_windows(number<BnIter>{})(BkIter));
|
||||
constexpr auto AmIter = (mIter + 2) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp);
|
||||
a_warp_tensor_ping(number<AwarpIter>{}) = load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
// Prefill A(loopK)
|
||||
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0))
|
||||
{
|
||||
constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum;
|
||||
store_tile(a_copy_lds_window_pong(number<AIter>{}), tile_elementwise_in(a_element_func, a_block_tile(number<AIter>{})));
|
||||
|
||||
//barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == (MIterPerWarp - 2)))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
#endif
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter!=KIterPerWarp-1)||(mIter<(MIterPerWarp-2)))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + 2) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp);
|
||||
a_warp_tensor_ping(number<AwarpIter>{}) = load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
//barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == (MIterPerWarp - 2)))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
});
|
||||
//block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_ping);
|
||||
//block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_ping);
|
||||
|
||||
TailHotLoopScheduler();
|
||||
TailHotLoopScheduler();
|
||||
|
||||
static_for<0, 2, 1>{}([&](auto mIter) {
|
||||
a_warp_tensor_pong(mIter) = load_tile(a_warp_windows_pong(mIter)(number<0>{}));
|
||||
});
|
||||
static_for<0, 2, 1>{}([&](auto mIter) {
|
||||
a_warp_tensor_pong(mIter) = load_tile(a_warp_windows_pong(mIter)(number<0>{}));
|
||||
});
|
||||
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// GEMM loopK
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = mIter % 2;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor_pong(number<AwarpIter>{}), b_warp_tensor_pong(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// GEMM loopK
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = mIter % 2;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor_pong(number<AwarpIter>{}), b_warp_tensor_pong(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
if constexpr((kIter!=KIterPerWarp-1)||(mIter<(MIterPerWarp-2)))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + 2) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp);
|
||||
a_warp_tensor_pong(number<AwarpIter>{}) = load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
});
|
||||
if constexpr((kIter!=KIterPerWarp-1)||(mIter<(MIterPerWarp-2)))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + 2) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp);
|
||||
a_warp_tensor_pong(number<AwarpIter>{}) = load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
});
|
||||
});
|
||||
// block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_pong);
|
||||
// block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_pong);
|
||||
|
||||
// TailHotLoopScheduler();
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
// TailHotLoopScheduler();
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
// GEMM loopK
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = mIter % 2;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor_ping(number<AwarpIter>{}), b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
|
||||
#ifdef FINEGRADE_LOADSTORE
|
||||
// prefetch B(loopK)
|
||||
constexpr auto curMNIter = mIter * NIterPerWarp + nIter;
|
||||
if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1))
|
||||
{
|
||||
constexpr auto BnIter = curMNIter / BloadGap;
|
||||
constexpr auto BkIter = kIter;
|
||||
b_flat_dram_windows(number<BnIter>{})(BkIter) = b_flat_dram_window;
|
||||
move_tile_window(b_flat_dram_windows(number<BnIter>{})(BkIter),
|
||||
{BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter});
|
||||
b_warp_tensor_pong(number<BnIter>{})(BkIter) = load_tile(b_flat_dram_windows(number<BnIter>{})(BkIter));
|
||||
}
|
||||
// Prefill A(loopK)
|
||||
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0))
|
||||
{
|
||||
constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum;
|
||||
store_tile(a_copy_lds_window_pong(number<AIter>{}), tile_elementwise_in(a_element_func, a_block_tile(number<AIter>{})));
|
||||
}
|
||||
#endif
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter!=KIterPerWarp-1)||(mIter<(MIterPerWarp-2)))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + 2) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp);
|
||||
a_warp_tensor_ping(number<AwarpIter>{}) = load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
//barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == (MIterPerWarp - 2)))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
// }
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
@@ -160,6 +160,148 @@ using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
|
||||
Traits_,
|
||||
ComputeDataType_>;
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
struct FlatmmPipelineProblem
|
||||
{
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Traits::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Traits::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Traits::CLayout>;
|
||||
|
||||
static constexpr bool TransposeC = Traits::TransposeC;
|
||||
|
||||
static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
|
||||
|
||||
static constexpr bool kPadM = Traits::kPadM;
|
||||
static constexpr bool kPadN = Traits::kPadN;
|
||||
static constexpr bool kPadK = Traits::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
|
||||
|
||||
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
|
||||
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
|
||||
|
||||
static constexpr auto HasHotLoop = HasHotLoop_;
|
||||
static constexpr auto TailNum = TailNum_;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "gemm_problem",
|
||||
concat('x', VectorLoadSize, kBlockSize),
|
||||
concat('x', kPadM, kPadN, kPadK),
|
||||
Scheduler);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
|
||||
{
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t pixels_per_thread =
|
||||
BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
|
||||
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType)
|
||||
? pixels_per_thread
|
||||
: PackedSize * VectorLoadSize / sizeof(ADataType);
|
||||
}
|
||||
else
|
||||
{
|
||||
return VectorLoadSize / sizeof(ADataType);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
|
||||
{
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t pixels_per_thread =
|
||||
BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
|
||||
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType)
|
||||
? pixels_per_thread
|
||||
: PackedSize * VectorLoadSize / sizeof(BDataType);
|
||||
}
|
||||
else
|
||||
{
|
||||
return PackedSize * VectorLoadSize / sizeof(BDataType);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC()
|
||||
{
|
||||
if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size());
|
||||
constexpr index_t M0 = get_warp_size() / N2;
|
||||
constexpr index_t M1 = BlockGemmShape::kM / M0;
|
||||
|
||||
return std::min(M1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size());
|
||||
constexpr index_t N0 = get_warp_size() / M2;
|
||||
constexpr index_t N1 = BlockGemmShape::kN / N0;
|
||||
|
||||
return std::min(N1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr index_t VectorSizeA = []() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return kPadK ? 1 : GetAlignmentA();
|
||||
}
|
||||
else
|
||||
{
|
||||
return kPadM ? 1 : GetAlignmentA();
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr index_t VectorSizeB = []() {
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return kPadN ? 1 : GetAlignmentB();
|
||||
}
|
||||
else
|
||||
{
|
||||
return kPadK ? 1 : GetAlignmentB();
|
||||
}
|
||||
}();
|
||||
static constexpr index_t VectorSizeC = []() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return kPadN ? 1 : GetAlignmentC();
|
||||
}
|
||||
else
|
||||
{
|
||||
return kPadM ? 1 : GetAlignmentC();
|
||||
}
|
||||
}();
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
|
||||
Reference in New Issue
Block a user