mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Rename more tile size constants
This commit is contained in:
@@ -60,7 +60,7 @@ struct FmhaFwdAppendKVKernel
|
||||
_SS_("fmha_fwd_appendkv_d") + _TS_(FmhaPipeline::kK0) + "_" + _SS_(t2s<QDataType>::name) +
|
||||
"_" + (kIsGroupMode ? "group" : "batch") + "_"
|
||||
"b" + _TS_(FmhaPipeline::kM0) + "x" + _TS_(FmhaPipeline::kN0) + "x" + _TS_(FmhaPipeline::kK0) + "x" +
|
||||
_TS_(FmhaPipeline::kTileSizeDv) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
|
||||
_TS_(FmhaPipeline::kN1) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn)
|
||||
+ (!kApplyRoPE ? _SS_("") : (_SS_("_") + RotaryEmbeddingEnumToStr<FmhaPipeline::RotaryEnum>::name));
|
||||
#undef _SS_
|
||||
@@ -439,7 +439,7 @@ struct FmhaFwdAppendKVKernel
|
||||
|
||||
return pad_tensor_view(
|
||||
v_dram_transposed,
|
||||
make_tuple(number<FmhaPipeline::kTileSizeDv>{}, number<FmhaPipeline::kN0>{}),
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK>{});
|
||||
}
|
||||
else
|
||||
@@ -453,7 +453,7 @@ struct FmhaFwdAppendKVKernel
|
||||
|
||||
return pad_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kTileSizeDv>{}, number<FmhaPipeline::kN0>{}),
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK>{});
|
||||
}
|
||||
}();
|
||||
@@ -476,7 +476,7 @@ struct FmhaFwdAppendKVKernel
|
||||
|
||||
return pad_tensor_view(
|
||||
vnew_dram_transposed,
|
||||
make_tuple(number<FmhaPipeline::kTileSizeDv>{}, number<FmhaPipeline::kN0>{}),
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK>{});
|
||||
}
|
||||
else
|
||||
@@ -490,7 +490,7 @@ struct FmhaFwdAppendKVKernel
|
||||
|
||||
return pad_tensor_view(
|
||||
vnew_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kTileSizeDv>{}, number<FmhaPipeline::kN0>{}),
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK>{});
|
||||
}
|
||||
}();
|
||||
@@ -616,15 +616,15 @@ struct FmhaFwdAppendKVKernel
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
{i_n0, 0});
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram,
|
||||
make_tuple(number<FmhaPipeline::kTileSizeDv>{}, number<FmhaPipeline::kN0>{}),
|
||||
{0, kargs.seqlen_k + i_n0});
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
|
||||
{0, kargs.seqlen_k + i_n0});
|
||||
|
||||
auto vnew_dram_window = make_tile_window(
|
||||
vnew_dram,
|
||||
make_tuple(number<FmhaPipeline::kTileSizeDv>{}, number<FmhaPipeline::kN0>{}),
|
||||
{0, i_n0});
|
||||
auto vnew_dram_window =
|
||||
make_tile_window(vnew_dram,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
|
||||
{0, i_n0});
|
||||
|
||||
if constexpr(kApplyRoPE)
|
||||
{
|
||||
|
||||
@@ -7,15 +7,15 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <index_t kM0_, index_t kN0_, index_t kK0_, index_t kTileSizeDv_>
|
||||
template <index_t kM0_, index_t kN0_, index_t kK0_, index_t kN1_>
|
||||
struct FmhaFwdAppendKVTilePartitioner
|
||||
{
|
||||
static constexpr ck_tile::index_t kM0 = kM0_;
|
||||
static constexpr ck_tile::index_t kN0 = kN0_;
|
||||
static constexpr ck_tile::index_t kK0 = kK0_;
|
||||
static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_;
|
||||
static constexpr ck_tile::index_t kM0 = kM0_;
|
||||
static constexpr ck_tile::index_t kN0 = kN0_;
|
||||
static constexpr ck_tile::index_t kK0 = kK0_;
|
||||
static constexpr ck_tile::index_t kN1 = kN1_;
|
||||
|
||||
static_assert(kK0 == kTileSizeDv);
|
||||
static_assert(kK0 == kN1);
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
|
||||
@@ -22,10 +22,10 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = Problem::kM0;
|
||||
static constexpr index_t kN0 = Problem::kN0;
|
||||
static constexpr index_t kK0 = Problem::kK0;
|
||||
static constexpr index_t kTileSizeDv = Problem::kTileSizeDv;
|
||||
static constexpr index_t kM0 = Problem::kM0;
|
||||
static constexpr index_t kN0 = Problem::kN0;
|
||||
static constexpr index_t kK0 = Problem::kK0;
|
||||
static constexpr index_t kN1 = Problem::kN1;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
@@ -94,8 +94,8 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
KDramBlockWindow& k_dram_block_window, // N0*K0 tile
|
||||
const KnewDramBlockWindow& knew_dram_block_window, // N0*K0 tile
|
||||
const KnewElementFunction& knew_element_func,
|
||||
VDramBlockWindow& v_dram_block_window, // N1*K1 tile
|
||||
const VnewDramBlockWindow& vnew_dram_block_window, // N1*K1 tile
|
||||
VDramBlockWindow& v_dram_block_window, // N1*N0 tile
|
||||
const VnewDramBlockWindow& vnew_dram_block_window, // N1*N0 tile
|
||||
const VnewElementFunction& vnew_element_func,
|
||||
const QRotaryCosDramBlockWindow q_rotary_cos_dram_block_window,
|
||||
const QRotarySinDramBlockWindow q_rotary_sin_dram_block_window,
|
||||
|
||||
@@ -34,7 +34,7 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::kTileSizeDv;
|
||||
constexpr index_t kKPerBlock = Problem::kN1;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
// TODO: not correct!
|
||||
@@ -188,7 +188,7 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::kTileSizeDv;
|
||||
constexpr index_t kNPerBlock = Problem::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::kN0;
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
|
||||
@@ -13,7 +13,7 @@ template <typename QDataType_,
|
||||
index_t kM0_,
|
||||
index_t kN0_,
|
||||
index_t kK0_,
|
||||
index_t kTileSizeDv_,
|
||||
index_t kN1_,
|
||||
bool IsVLayoutRowMajor_,
|
||||
bool kIsGroupMode_,
|
||||
typename Traits_>
|
||||
@@ -27,10 +27,10 @@ struct BlockFmhaFwdAppendKVPipelineProblem
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
|
||||
static constexpr index_t kM0 = kM0_;
|
||||
static constexpr index_t kN0 = kN0_;
|
||||
static constexpr index_t kK0 = kK0_;
|
||||
static constexpr index_t kTileSizeDv = kTileSizeDv_;
|
||||
static constexpr index_t kM0 = kM0_;
|
||||
static constexpr index_t kN0 = kN0_;
|
||||
static constexpr index_t kK0 = kK0_;
|
||||
static constexpr index_t kN1 = kN1_;
|
||||
|
||||
using VLayout = std::conditional_t<IsVLayoutRowMajor_,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
|
||||
Reference in New Issue
Block a user