clang format

This commit is contained in:
AviralGoelAMD
2025-07-16 21:52:16 +00:00
parent d43f035761
commit 6d38fd3673
10 changed files with 118 additions and 143 deletions

View File

@@ -220,9 +220,9 @@ int run_flatmm_example(int argc, char* argv[])
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "C")
{
if(data_type == "fp16")
{
{
std::cout << "Running with fp16 data type" << std::endl;
run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
argc, argv, Row{}, Col{}, Row{});
@@ -264,7 +264,7 @@ int main(int argc, char* argv[])
{
int warp_tile = arg_parser.get_int("warp_tile");
if(warp_tile == 0)
{
{
std::cout << "Running with warp tile size 16x16" << std::endl;
return !run_flatmm_example<FlatmmConfig16>(argc, argv);
}

47
include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp Executable file → Normal file
View File

@@ -590,27 +590,27 @@ struct FlatmmKernel
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window, c_block_tile, d_block_window, smem_ptr);
}
}
}
CK_TILE_DEVICE static void RunFlatmm2(const ADataType* a_ptr,
const BDataType* b_flat_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
void* smem_ptr_ping,
void* smem_ptr_pong,
const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
const BDataType* b_flat_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
void* smem_ptr_ping,
void* smem_ptr_pong,
const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
@@ -651,7 +651,8 @@ struct FlatmmKernel
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value) && FlatmmPipeline::DoubleSmemBuffer == false)
is_any_of<EDataType, fp16_t, bf16_t>::value) &&
FlatmmPipeline::DoubleSmemBuffer == false)
{
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
RunFlatmm<scheduler_type>(a_ptr,
@@ -667,15 +668,15 @@ struct FlatmmKernel
else
{
RunFlatmm2(a_ptr,
b_flat_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr,
smem_ptr_pong,
kargs,
splitk_batch_offset,
i_m,
i_n);
b_flat_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr,
smem_ptr_pong,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
}
};

View File

@@ -51,7 +51,8 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV2
}
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber tail_number)
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool, TailNumber tail_number)
{
if(tail_number == TailNumber::Odd)
{
@@ -499,7 +500,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy>
struct FlatmmPipelineAGmemBGmemCRegV2 : public BaseFlatmmPipelineAGmemBGmemCRegV2<Problem>
{
using Base = BaseFlatmmPipelineAGmemBGmemCRegV2<Problem>;
using Base = BaseFlatmmPipelineAGmemBGmemCRegV2<Problem>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
@@ -510,8 +511,6 @@ struct FlatmmPipelineAGmemBGmemCRegV2 : public BaseFlatmmPipelineAGmemBGmemCRegV
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using BlockFlatmm =
remove_cvref_t<decltype(PipelinePolicy::template GetBlockFlatmm<Problem>())>;
@@ -585,7 +584,6 @@ struct FlatmmPipelineAGmemBGmemCRegV2 : public BaseFlatmmPipelineAGmemBGmemCRegV
static constexpr MfmaConfig GetMfmaConfig()
{
// K1 per Mfma = 0.5 cases: mfma_per_wg = 2, dsread_per_wg = 1
if constexpr((warp_m == 16 && warp_n == 16 && warp_k == 32 &&
std::is_same_v<ADataType, fp8_t>) ||
@@ -645,7 +643,6 @@ struct FlatmmPipelineAGmemBGmemCRegV2 : public BaseFlatmmPipelineAGmemBGmemCRegV
// clang-format on
}
static constexpr bool DoubleSmemBuffer = true;
static constexpr index_t Preshuffle = Problem::Preshuffle;
using Base::UsePersistentKernel;
@@ -1024,11 +1021,13 @@ struct FlatmmPipelineAGmemBGmemCRegV2 : public BaseFlatmmPipelineAGmemBGmemCRegV
}
else
{
if constexpr ((A_LDS_Read_Inst_Num / 2 >
A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num)) {
if constexpr((A_LDS_Read_Inst_Num / 2 >
A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num))
{
static_for<0,
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num,
1>{}([&](auto i) {
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num -
B_Buffer_Load_Inst_Num,
1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA

View File

@@ -339,36 +339,36 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
}
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
//using ALayout = remove_cvref_t<typename Problem::ALayout>;
// using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
// constexpr index_t M0 = MPerBlock / (M2 * M1);
// static_assert(M0 * M1 * M2 == MPerBlock,
// "Incorrect M0, M2, M1 configuration! "
// "M0, M1, M2 must cover whole MPerBlock!");
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
// constexpr index_t M0 = MPerBlock / (M2 * M1);
// static_assert(M0 * M1 * M2 == MPerBlock,
// "Incorrect M0, M2, M1 configuration! "
// "M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2>,
sequence<1>>{});
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2>,
sequence<1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()

View File

@@ -15,9 +15,9 @@
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
@@ -29,14 +29,14 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"

View File

@@ -13,9 +13,9 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp"
@@ -44,12 +44,12 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"

View File

@@ -18,31 +18,31 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeADramDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
//using ALayout = remove_cvref_t<typename Problem::ALayout>;
// using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
// constexpr index_t M0 = MPerBlock / (M2 * M1);
// static_assert(M0 * M1 * M2 == MPerBlock,
// "Incorrect M0, M2, M1 configuration! "
// "M0, M1, M2 must cover whole MPerBlock!");
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
// constexpr index_t M0 = MPerBlock / (M2 * M1);
// static_assert(M0 * M1 * M2 == MPerBlock,
// "Incorrect M0, M2, M1 configuration! "
// "M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2>,
sequence<1>>{});
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2>,
sequence<1>>{});
}
// 3d + padding
template <typename Problem>

View File

@@ -33,7 +33,6 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV1
}
};
template <typename Problem, typename PipelinePolicy = UniversalWeightPreshufflePipelineAgBgCrPolicy>
struct WeightPreshufflePipelineAGmemBGmemCRegV1
: public BaseWeightPreshufflePipelineAGmemBGmemCRegV1<Problem>
@@ -73,7 +72,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr index_t kLdsAlignmentInBytes = Problem::VectorLoadSize/sizeof(ADataType);
static constexpr index_t kLdsAlignmentInBytes = Problem::VectorLoadSize / sizeof(ADataType);
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
static constexpr auto I0 = number<0>();

View File

@@ -28,7 +28,8 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
}
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber tail_number)
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool, TailNumber tail_number)
{
if(tail_number == TailNumber::Odd)
{
@@ -42,9 +43,10 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
};
template <typename Problem, typename PipelinePolicy = UniversalWeightPreshufflePipelineAgBgCrPolicy>
struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>
struct WeightPreshufflePipelineAGmemBGmemCRegV2
: public BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>
{
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>;
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
@@ -79,7 +81,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr index_t kLdsAlignmentInBytes = Problem::VectorLoadSize/sizeof(ADataType);
static constexpr index_t kLdsAlignmentInBytes = Problem::VectorLoadSize / sizeof(ADataType);
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
static constexpr auto I0 = number<0>();
@@ -105,18 +107,14 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
static constexpr index_t K1 = 16 / sizeof(ADataType);
static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1;
static constexpr auto TailNum = Problem::TailNum;
static constexpr index_t K1 = 16 / sizeof(ADataType);
static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1;
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto warp_m = WarpTile::at(idxM);
static constexpr auto warp_n = WarpTile::at(idxN);
static constexpr auto warp_k = WarpTile::at(idxK);
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
@@ -129,7 +127,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
// clang-format on
}
static constexpr bool DoubleSmemBuffer = true;
static constexpr index_t Preshuffle = Problem::Preshuffle;
using Base::UsePersistentKernel;
@@ -508,11 +505,13 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
}
else
{
if constexpr ((A_LDS_Read_Inst_Num / 2 >
A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num)) {
if constexpr((A_LDS_Read_Inst_Num / 2 >
A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num))
{
static_for<0,
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num,
1>{}([&](auto i) {
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num -
B_Buffer_Load_Inst_Num,
1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
@@ -587,7 +586,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
auto a_lds_block_pong =
make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
// A DRAM tile window for load
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
@@ -606,8 +605,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
{0, 0},
PipelinePolicy::template MakeADramTileDistribution<Problem>());
// ping-pong window for A LDS
auto a_warp_window_ping_tmp =
make_tile_window(a_lds_block_ping,
@@ -680,12 +677,11 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
NIterPerWarp>
b_warp_tensor_pong;
// Prefetch A0
// Prefetch A0
auto a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// prefetch B
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
@@ -700,18 +696,17 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// Prefill A0
// Prefill A0
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
__builtin_amdgcn_sched_barrier(0);
// Prefetch A1
// Prefetch A1
a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
@@ -734,8 +729,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
});
__builtin_amdgcn_sched_barrier(0);
index_t iCounter = (num_loop - 1) / 2;
while(iCounter > 0)
{
@@ -783,7 +776,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
@@ -859,7 +851,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
@@ -896,7 +887,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
// tail
if constexpr(TailNum == TailNumber::Even)
{
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_sched_barrier(0);
// prefetch B(loopK)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
@@ -936,7 +927,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
@@ -1028,7 +1018,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
@@ -1050,7 +1039,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
});
}
return c_block_tile;
}
@@ -1071,5 +1059,4 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
}
};
} // namespace ck_tile

View File

@@ -28,7 +28,8 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV3
}
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber tail_number)
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool, TailNumber tail_number)
{
if(tail_number == TailNumber::Odd)
{
@@ -42,9 +43,10 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV3
};
template <typename Problem, typename PipelinePolicy = UniversalWeightPreshufflePipelineAgBgCrPolicy>
struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePipelineAGmemBGmemCRegV3<Problem>
struct WeightPreshufflePipelineAGmemBGmemCRegV3
: public BaseWeightPreshufflePipelineAGmemBGmemCRegV3<Problem>
{
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV3<Problem>;
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV3<Problem>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
@@ -118,10 +120,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
static constexpr auto warp_n = WarpTile::at(idxN);
static constexpr auto warp_k = WarpTile::at(idxK);
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
@@ -135,7 +133,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
// clang-format on
}
static constexpr bool DoubleSmemBuffer = true;
static constexpr index_t Preshuffle = Problem::Preshuffle;
using Base::UsePersistentKernel;
@@ -514,11 +511,13 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
}
else
{
if constexpr ((A_LDS_Read_Inst_Num / 2 >
A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num)) {
if constexpr((A_LDS_Read_Inst_Num / 2 >
A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num))
{
static_for<0,
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num,
1>{}([&](auto i) {
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num -
B_Buffer_Load_Inst_Num,
1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
@@ -552,7 +551,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
}
}
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
@@ -570,7 +568,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
// static assert that warptile is 16x16 and not 32x32
static_assert(WG::kM == 16 && WG::kN == 16, "For pipeline_AGmemBGmemCRegV3, WarpTile must be 16x16, not 32x32");
static_assert(WG::kM == 16 && WG::kN == 16,
"For pipeline_AGmemBGmemCRegV3, WarpTile must be 16x16, not 32x32");
constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
const index_t iMWarp = get_warp_id() / NWarp;
@@ -596,7 +595,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
auto a_lds_block_pong =
make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
// A DRAM tile window for load
// A DRAM tile window for load
auto a_copy_dram_window_tmp =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
@@ -636,8 +635,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
move_tile_window(a_copy_lds_window_pong(AIter), {AIter * ACopyPerLoadM, 0});
});
// ping-pong window for A LDS
auto a_warp_window_ping_tmp =
make_tile_window(a_lds_block_ping,
@@ -710,7 +707,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
NIterPerWarp>
b_warp_tensor_pong;
// Prefetch A0
// Prefetch A0
statically_indexed_array<decltype(load_tile(a_copy_dram_window(number<0>{}))), ACopyLoadNum>
a_block_tile;
@@ -733,7 +730,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// Prefill A0
// Prefill A0
static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) {
store_tile(a_copy_lds_window_ping(AIter),
@@ -742,7 +739,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
__builtin_amdgcn_sched_barrier(0);
// Prefetch A1
// Prefetch A1
static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) {
a_block_tile(AIter) = load_tile(a_copy_dram_window(AIter));
@@ -771,13 +768,10 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
});
__builtin_amdgcn_sched_barrier(0);
index_t iCounter = (num_loop - 1) / 2;
while(iCounter > 0)
{
// GEMM 2i
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
@@ -801,7 +795,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
// prefetch B(2i+1)
constexpr auto curMNIter = mIter * NIterPerWarp + nIter;
if constexpr((curMNIter < NIterPerWarp * BLoadGap) &&
@@ -872,7 +865,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
// Next K
// GEMM 2i+1
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
@@ -895,7 +887,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
// prefetch B(2i+2)
constexpr auto curMNIter = mIter * NIterPerWarp + nIter;
if constexpr((curMNIter < NIterPerWarp * BLoadGap) &&
@@ -1153,7 +1144,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
});
}
return c_block_tile;
}
@@ -1174,5 +1164,4 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
}
};
} // namespace ck_tile