mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
clang format
This commit is contained in:
@@ -220,9 +220,9 @@ int run_flatmm_example(int argc, char* argv[])
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
{
|
||||
std::cout << "Running with fp16 data type" << std::endl;
|
||||
run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
@@ -264,7 +264,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
int warp_tile = arg_parser.get_int("warp_tile");
|
||||
if(warp_tile == 0)
|
||||
{
|
||||
{
|
||||
std::cout << "Running with warp tile size 16x16" << std::endl;
|
||||
return !run_flatmm_example<FlatmmConfig16>(argc, argv);
|
||||
}
|
||||
|
||||
47
include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp
Executable file → Normal file
47
include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp
Executable file → Normal file
@@ -590,27 +590,27 @@ struct FlatmmKernel
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
CK_TILE_DEVICE static void RunFlatmm2(const ADataType* a_ptr,
|
||||
const BDataType* b_flat_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* smem_ptr_ping,
|
||||
void* smem_ptr_pong,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
const BDataType* b_flat_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* smem_ptr_ping,
|
||||
void* smem_ptr_pong,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
|
||||
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
|
||||
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
|
||||
@@ -651,7 +651,8 @@ struct FlatmmKernel
|
||||
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value) && FlatmmPipeline::DoubleSmemBuffer == false)
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value) &&
|
||||
FlatmmPipeline::DoubleSmemBuffer == false)
|
||||
{
|
||||
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
|
||||
RunFlatmm<scheduler_type>(a_ptr,
|
||||
@@ -667,15 +668,15 @@ struct FlatmmKernel
|
||||
else
|
||||
{
|
||||
RunFlatmm2(a_ptr,
|
||||
b_flat_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr,
|
||||
smem_ptr_pong,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
b_flat_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr,
|
||||
smem_ptr_pong,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -51,7 +51,8 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV2
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber tail_number)
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool, TailNumber tail_number)
|
||||
{
|
||||
if(tail_number == TailNumber::Odd)
|
||||
{
|
||||
@@ -499,7 +500,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy>
|
||||
struct FlatmmPipelineAGmemBGmemCRegV2 : public BaseFlatmmPipelineAGmemBGmemCRegV2<Problem>
|
||||
{
|
||||
using Base = BaseFlatmmPipelineAGmemBGmemCRegV2<Problem>;
|
||||
using Base = BaseFlatmmPipelineAGmemBGmemCRegV2<Problem>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
@@ -510,8 +511,6 @@ struct FlatmmPipelineAGmemBGmemCRegV2 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
|
||||
|
||||
using BlockFlatmm =
|
||||
remove_cvref_t<decltype(PipelinePolicy::template GetBlockFlatmm<Problem>())>;
|
||||
|
||||
@@ -585,7 +584,6 @@ struct FlatmmPipelineAGmemBGmemCRegV2 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
static constexpr MfmaConfig GetMfmaConfig()
|
||||
{
|
||||
|
||||
|
||||
// K1 per Mfma = 0.5 cases: mfma_per_wg = 2, dsread_per_wg = 1
|
||||
if constexpr((warp_m == 16 && warp_n == 16 && warp_k == 32 &&
|
||||
std::is_same_v<ADataType, fp8_t>) ||
|
||||
@@ -645,7 +643,6 @@ struct FlatmmPipelineAGmemBGmemCRegV2 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
using Base::UsePersistentKernel;
|
||||
@@ -1024,11 +1021,13 @@ struct FlatmmPipelineAGmemBGmemCRegV2 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr ((A_LDS_Read_Inst_Num / 2 >
|
||||
A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num)) {
|
||||
if constexpr((A_LDS_Read_Inst_Num / 2 >
|
||||
A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num))
|
||||
{
|
||||
static_for<0,
|
||||
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num,
|
||||
1>{}([&](auto i) {
|
||||
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num -
|
||||
B_Buffer_Load_Inst_Num,
|
||||
1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
|
||||
@@ -339,36 +339,36 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
//using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
// using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
// constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
// static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
// "Incorrect M0, M2, M1 configuration! "
|
||||
// "M0, M1, M2 must cover whole MPerBlock!");
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
// constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
// static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
// "Incorrect M0, M2, M1 configuration! "
|
||||
// "M0, M1, M2 must cover whole MPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
}
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
|
||||
|
||||
@@ -15,9 +15,9 @@
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
|
||||
@@ -29,14 +29,14 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
|
||||
|
||||
@@ -13,9 +13,9 @@
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp"
|
||||
@@ -44,12 +44,12 @@
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v3.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
|
||||
|
||||
@@ -18,31 +18,31 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
//using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
// using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
// constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
// static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
// "Incorrect M0, M2, M1 configuration! "
|
||||
// "M0, M1, M2 must cover whole MPerBlock!");
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
// constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
// static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
// "Incorrect M0, M2, M1 configuration! "
|
||||
// "M0, M1, M2 must cover whole MPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
}
|
||||
// 3d + padding
|
||||
template <typename Problem>
|
||||
|
||||
@@ -33,7 +33,6 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename Problem, typename PipelinePolicy = UniversalWeightPreshufflePipelineAgBgCrPolicy>
|
||||
struct WeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
: public BaseWeightPreshufflePipelineAGmemBGmemCRegV1<Problem>
|
||||
@@ -73,7 +72,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr index_t kLdsAlignmentInBytes = Problem::VectorLoadSize/sizeof(ADataType);
|
||||
static constexpr index_t kLdsAlignmentInBytes = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
|
||||
@@ -28,7 +28,8 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber tail_number)
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool, TailNumber tail_number)
|
||||
{
|
||||
if(tail_number == TailNumber::Odd)
|
||||
{
|
||||
@@ -42,9 +43,10 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
};
|
||||
|
||||
template <typename Problem, typename PipelinePolicy = UniversalWeightPreshufflePipelineAgBgCrPolicy>
|
||||
struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>
|
||||
struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
: public BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>
|
||||
{
|
||||
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>;
|
||||
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
@@ -79,7 +81,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr index_t kLdsAlignmentInBytes = Problem::VectorLoadSize/sizeof(ADataType);
|
||||
static constexpr index_t kLdsAlignmentInBytes = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
@@ -105,18 +107,14 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
|
||||
static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
|
||||
static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
|
||||
|
||||
static constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
static constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
|
||||
static constexpr auto warp_m = WarpTile::at(idxM);
|
||||
static constexpr auto warp_n = WarpTile::at(idxN);
|
||||
static constexpr auto warp_k = WarpTile::at(idxK);
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
@@ -129,7 +127,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
using Base::UsePersistentKernel;
|
||||
@@ -508,11 +505,13 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr ((A_LDS_Read_Inst_Num / 2 >
|
||||
A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num)) {
|
||||
if constexpr((A_LDS_Read_Inst_Num / 2 >
|
||||
A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num))
|
||||
{
|
||||
static_for<0,
|
||||
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num,
|
||||
1>{}([&](auto i) {
|
||||
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num -
|
||||
B_Buffer_Load_Inst_Num,
|
||||
1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
@@ -587,7 +586,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
|
||||
auto a_lds_block_pong =
|
||||
make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
|
||||
|
||||
// A DRAM tile window for load
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
@@ -606,8 +605,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
|
||||
|
||||
// ping-pong window for A LDS
|
||||
auto a_warp_window_ping_tmp =
|
||||
make_tile_window(a_lds_block_ping,
|
||||
@@ -680,12 +677,11 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_pong;
|
||||
|
||||
// Prefetch A0
|
||||
// Prefetch A0
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
|
||||
// prefetch B
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
@@ -700,18 +696,17 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
// Prefill A0
|
||||
// Prefill A0
|
||||
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Prefetch A1
|
||||
// Prefetch A1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
@@ -734,8 +729,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
|
||||
|
||||
index_t iCounter = (num_loop - 1) / 2;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
@@ -783,7 +776,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
// preload next A from lds
|
||||
@@ -859,7 +851,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
// preload next A from lds
|
||||
@@ -896,7 +887,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
// prefetch B(loopK)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
@@ -936,7 +927,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
// preload next A from lds
|
||||
@@ -1028,7 +1018,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
// preload next A from lds
|
||||
@@ -1050,7 +1039,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
@@ -1071,5 +1059,4 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -28,7 +28,8 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV3
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber tail_number)
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool, TailNumber tail_number)
|
||||
{
|
||||
if(tail_number == TailNumber::Odd)
|
||||
{
|
||||
@@ -42,9 +43,10 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV3
|
||||
};
|
||||
|
||||
template <typename Problem, typename PipelinePolicy = UniversalWeightPreshufflePipelineAgBgCrPolicy>
|
||||
struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePipelineAGmemBGmemCRegV3<Problem>
|
||||
struct WeightPreshufflePipelineAGmemBGmemCRegV3
|
||||
: public BaseWeightPreshufflePipelineAGmemBGmemCRegV3<Problem>
|
||||
{
|
||||
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV3<Problem>;
|
||||
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV3<Problem>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
@@ -118,10 +120,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
|
||||
static constexpr auto warp_n = WarpTile::at(idxN);
|
||||
static constexpr auto warp_k = WarpTile::at(idxK);
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
@@ -135,7 +133,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
using Base::UsePersistentKernel;
|
||||
@@ -514,11 +511,13 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr ((A_LDS_Read_Inst_Num / 2 >
|
||||
A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num)) {
|
||||
if constexpr((A_LDS_Read_Inst_Num / 2 >
|
||||
A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num))
|
||||
{
|
||||
static_for<0,
|
||||
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num,
|
||||
1>{}([&](auto i) {
|
||||
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num -
|
||||
B_Buffer_Load_Inst_Num,
|
||||
1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
@@ -552,7 +551,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
@@ -570,7 +568,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
|
||||
static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
// static assert that warptile is 16x16 and not 32x32
|
||||
static_assert(WG::kM == 16 && WG::kN == 16, "For pipeline_AGmemBGmemCRegV3, WarpTile must be 16x16, not 32x32");
|
||||
static_assert(WG::kM == 16 && WG::kN == 16,
|
||||
"For pipeline_AGmemBGmemCRegV3, WarpTile must be 16x16, not 32x32");
|
||||
|
||||
constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
@@ -596,7 +595,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
|
||||
auto a_lds_block_pong =
|
||||
make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
|
||||
|
||||
// A DRAM tile window for load
|
||||
// A DRAM tile window for load
|
||||
|
||||
auto a_copy_dram_window_tmp =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
@@ -636,8 +635,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
|
||||
move_tile_window(a_copy_lds_window_pong(AIter), {AIter * ACopyPerLoadM, 0});
|
||||
});
|
||||
|
||||
|
||||
|
||||
// ping-pong window for A LDS
|
||||
auto a_warp_window_ping_tmp =
|
||||
make_tile_window(a_lds_block_ping,
|
||||
@@ -710,7 +707,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_pong;
|
||||
|
||||
// Prefetch A0
|
||||
// Prefetch A0
|
||||
|
||||
statically_indexed_array<decltype(load_tile(a_copy_dram_window(number<0>{}))), ACopyLoadNum>
|
||||
a_block_tile;
|
||||
@@ -733,7 +730,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
// Prefill A0
|
||||
// Prefill A0
|
||||
|
||||
static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) {
|
||||
store_tile(a_copy_lds_window_ping(AIter),
|
||||
@@ -742,7 +739,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Prefetch A1
|
||||
// Prefetch A1
|
||||
|
||||
static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) {
|
||||
a_block_tile(AIter) = load_tile(a_copy_dram_window(AIter));
|
||||
@@ -771,13 +768,10 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
|
||||
|
||||
index_t iCounter = (num_loop - 1) / 2;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
|
||||
|
||||
// GEMM 2i
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
@@ -801,7 +795,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
|
||||
|
||||
// prefetch B(2i+1)
|
||||
constexpr auto curMNIter = mIter * NIterPerWarp + nIter;
|
||||
if constexpr((curMNIter < NIterPerWarp * BLoadGap) &&
|
||||
@@ -872,7 +865,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
|
||||
|
||||
// Next K
|
||||
|
||||
|
||||
// GEMM 2i+1
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
@@ -895,7 +887,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
|
||||
|
||||
// prefetch B(2i+2)
|
||||
constexpr auto curMNIter = mIter * NIterPerWarp + nIter;
|
||||
if constexpr((curMNIter < NIterPerWarp * BLoadGap) &&
|
||||
@@ -1153,7 +1144,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
@@ -1174,5 +1164,4 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user