[CK_TILE] More fmha splitkv optimizations (#1588)

* Use pre-defined constants for readability

* Use vector write for o_acc tensor

* Remove no-longer used policy method

* Deprecate no-longer used policy/pipeline

* Specify gemm0/gemm1 block warps separately in codegen

* Fix wrong ps_idx creation logic

* Add single-warp block gemm

* Supoprt single-warp gemm0

* Make MakeCBlockTile() as static method

* Use MakeCBlockTile() to get underlying tile distribution

* Use kNumGemm1Warps to compute # threads for gemm1

* Put normal case in the if clause

* Refine fmha splitkv block mapping

* Refine & fix the lse_acc/o_acc layout

* Fix wrong LDS size for K tile

* Use kK0=64 for hdim=128,256 fmha splitkv kernels

* Use kK1=64 for hdim=32,64,128 fmha splitkv kernels

* Undo kK0/kK1 changes

* Use more reasonable GetAlignmentV() computation

* Using store_tile() in fmha splitkv kernel epilogue
This commit is contained in:
Po Yen Chen
2024-10-26 18:35:45 +08:00
committed by GitHub
parent 37f7afed1e
commit 54f0e6f4bb
22 changed files with 422 additions and 199 deletions

View File

@@ -69,7 +69,8 @@ struct FmhaFwdKernel
// sync with generate.py
// clang-format off
using bfs = typename FmhaPipeline::BlockFmhaShape;
using gbr = typename bfs::Gemm0BlockWarps;
using g0br = typename bfs::Gemm0BlockWarps;
using g1br = typename bfs::Gemm1BlockWarps;
using gwt = typename bfs::Gemm0WarpTile;
#define _SS_ std::string
#define _TS_ std::to_string
@@ -85,7 +86,8 @@ struct FmhaFwdKernel
"_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_"
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +

View File

@@ -65,7 +65,8 @@ struct FmhaFwdSplitKVKernel
// sync with generate.py
// clang-format off
using bfs = typename FmhaPipeline::BlockFmhaShape;
using gbr = typename bfs::Gemm0BlockWarps;
using g0br = typename bfs::Gemm0BlockWarps;
using g1br = typename bfs::Gemm1BlockWarps;
using gwt = typename bfs::Gemm0WarpTile;
#define _SS_ std::string
#define _TS_ std::to_string
@@ -81,7 +82,8 @@ struct FmhaFwdSplitKVKernel
"_" + (kIsGroupMode ? "group" : "batch") + "_"
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
@@ -894,7 +896,7 @@ struct FmhaFwdSplitKVKernel
o_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.stride_o_acc, 1),
number<1>{},
number<FmhaPipeline::kAlignmentOacc>{},
number<1>{});
return pad_tensor_view(

View File

@@ -26,8 +26,8 @@ struct FmhaFwdSplitKVTilePartitioner
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
ck_tile::integer_divide_ceil(hdim_v, kN1),
nhead * num_splits,
ck_tile::integer_divide_ceil(hdim_v, kN1) * num_splits,
nhead,
batch_size);
}
@@ -42,8 +42,9 @@ struct FmhaFwdSplitKVTilePartitioner
return ck_tile::make_tuple(quotient, modulus);
};
const auto [i_tile_m, i_tile_n] = f(blockIdx.x, num_tile_n1);
const auto [i_nhead, i_split] = f(blockIdx.y, num_splits);
const auto [mn, i_split] = f(blockIdx.x, num_splits);
const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1);
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);

View File

@@ -64,6 +64,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentOacc =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOacc<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
@@ -252,11 +255,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
k_dram_block_window_lengths, {adjusted_seqlen_k_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window = make_tile_window(
bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), adjusted_seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), adjusted_seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths,

View File

@@ -9,11 +9,20 @@
namespace ck_tile {
// This pipeline is qkv all located in LDS
using BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy =
BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>;
struct BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
{
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
return static_cast<index_t>(16 / sizeof(OaccDataType));
}
};
} // namespace ck_tile

View File

@@ -39,8 +39,11 @@ struct BlockFmhaPipelineProblem
using FmhaMask = remove_cvref_t<FmhaMask_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
@@ -84,8 +87,11 @@ struct BlockFmhaFwdSplitKVPipelineProblem
using FmhaMask = remove_cvref_t<FmhaMask_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;

View File

@@ -242,11 +242,11 @@ struct BlockFmhaPipelineQRKSVS
{seqlen_k_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window = make_tile_window(
bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);

View File

@@ -314,11 +314,11 @@ struct BlockFmhaPipelineQRKSVSAsync
}();
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window = make_tile_window(
bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);

View File

@@ -9,9 +9,10 @@
namespace ck_tile {
/// NOTICE: we no-longer use this pipeline.
// This pipeline is qkv all located in LDS
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQSKSVSDefaultPolicy>
struct BlockFmhaPipelineQSKSVS
struct [[deprecated]] BlockFmhaPipelineQSKSVS
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;

View File

@@ -15,6 +15,7 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
// TODO: remove this
#define K_LDS_LOAD_USE_OFFSET_TRANSFORM 0
@@ -64,13 +65,28 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
constexpr index_t M1 = MWarp;
constexpr index_t M0 = kMPerBlock / (M2 * M1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
tuple<sequence<1>, sequence<2, 1>>,
tuple<sequence<1>, sequence<1, 2>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{});
if constexpr(1 < Problem::kNumGemm0Warps)
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
tuple<sequence<1>, sequence<2, 1>>,
tuple<sequence<1>, sequence<1, 2>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{});
}
else
{
static_assert(MWarp == 1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 2>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{});
}
}
template <typename Problem>
@@ -80,7 +96,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
BlockGemmProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::kBlockSize,
Problem::kNumGemm0Warps * get_warp_size(),
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
@@ -129,12 +145,16 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
decltype(warp_gemm)>;
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
if constexpr(1 < Problem::kNumGemm0Warps)
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
else
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
}
};
/// NOTICE: we no-longer use this policy.
template <>
struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
struct [[deprecated]] BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
{
static constexpr bool QLoadOnce = false;
@@ -364,12 +384,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t kMaxVecLoad =
min(total_pixels, static_cast<index_t>(16 / sizeof(VDataType)));
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
// TODO: not correct!
if constexpr(total_pixels > 4)
return 4;
else
return 2;
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (total_pixels / kMinVecLoad);
return kVecLoad;
}
else
{
@@ -383,10 +406,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
using BlockGemm = remove_cvref_t<decltype(QXPolicy::template GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
using CWarpDstr = typename WG::CWarpDstr;
constexpr auto vec =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number<CWarpDstr::NDimY - 1>{});
return vec;
return WG::WarpGemmAttribute::Impl::kCM1PerLane;
}
template <typename Problem>
@@ -395,10 +416,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
using CWarpDstr = typename WG::CWarpDstr;
constexpr auto vec =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number<CWarpDstr::NDimY - 1>{});
return vec;
return WG::WarpGemmAttribute::Impl::kCM1PerLane;
}
template <typename Problem>
@@ -449,44 +468,12 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
return max(SingleKSize, SingleVSize);
}
template <typename Problem, typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegBlockDescriptor()
{
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
constexpr auto q_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
q_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode);
return q_block_dstr;
}
// TODO: this is used for non async copy desc. unify in the future
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
@@ -886,36 +873,10 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
}
}
template <typename Problem, typename BlockGemm>
template <typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution()
{
constexpr index_t MPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t NPerBlock = Problem::BlockFmhaShape::kN0;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
// Construct C-Block-HostTensor
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
return c_block_dstr;
return BlockGemm::MakeCBlockTile().get_tile_distribution();
}
template <typename Problem>
@@ -972,7 +933,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
BlockGemmProblem<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::kBlockSize,
Problem::kNumGemm1Warps * get_warp_size(),
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>,

View File

@@ -21,10 +21,15 @@ struct TileFmhaShape
using Gemm1BlockWarps = remove_cvref_t<Gemm1BlockWarps_>;
using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
static constexpr index_t NumWarps =
static constexpr index_t NumGemm0Warps =
reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
static constexpr index_t NumGemm1Warps =
reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{});
static_assert(NumGemm1Warps % NumGemm0Warps == 0);
static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}));
static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
static_assert(std::is_same_v<Gemm0WarpTile, Gemm1WarpTile>);
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen

View File

@@ -8,6 +8,7 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"

View File

@@ -157,7 +157,7 @@ struct BlockGemmARegBRegCRegV1
});
}
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;

View File

@@ -0,0 +1,237 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
namespace ck_tile {
// A is block distributed tensor
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV1DefaultPolicy>
struct BlockGemmARegBSmemCRegOneWarpV1
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static_assert(kBlockSize == get_warp_size(), "Check failed!");
// C += A * B
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
// constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
// constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
// constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr index_t KPerBlock = BlockGemmShape::kK;
// static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
// KPerBlock == BlockGemmShape::kK,
// "wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
static_assert(MWarp == 1 && NWarp == 1, "Check failed!");
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iNWarp = 0;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp>>,
tuple<>,
tuple<>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
// constrcut from A-block-tensor from A-Block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto a_block_tensor =
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
{b_warp_window_tmp}};
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array<
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
// check C-block-distribution
static_assert(
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"wrong!");
using AWarpDstr = typename WG::AWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
static_assert(MWarp == 1 && NWarp == 1, "Check failed!");
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp>>,
tuple<>,
tuple<>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
static_assert(decltype(c_block_dstr_encode)::NDimP == 1, "Check failed!");
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
// C = A * B
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
auto c_block_tensor = MakeCBlockTile();
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
return c_block_tensor;
}
};
} // namespace ck_tile

View File

@@ -181,7 +181,7 @@ struct BlockGemmARegBSmemCRegV1
});
}
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;

View File

@@ -182,7 +182,7 @@ struct BlockGemmARegBSmemCRegV2
});
}
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;

View File

@@ -180,7 +180,7 @@ struct BlockGemmASmemBRegCRegV1
});
}
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;

View File

@@ -167,7 +167,7 @@ struct BlockGemmASmemBSmemCRegV1
});
}
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;

View File

@@ -22,7 +22,7 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
constexpr index_t idim_p_lane = NDimP - 1;
const auto ps_idx = make_array<index_t>(get_block_id(), get_lane_id());
const auto ps_idx = detail::get_partition_index(acc_tensor.get_tile_distribution());
const auto rs_idx = acc_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();