Rename more tile size constants

This commit is contained in:
PoYen, Chen
2024-07-23 09:30:05 +00:00
parent 99c1d463de
commit 52b47810bb
5 changed files with 32 additions and 32 deletions

View File

@@ -60,7 +60,7 @@ struct FmhaFwdAppendKVKernel
_SS_("fmha_fwd_appendkv_d") + _TS_(FmhaPipeline::kK0) + "_" + _SS_(t2s<QDataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_"
"b" + _TS_(FmhaPipeline::kM0) + "x" + _TS_(FmhaPipeline::kN0) + "x" + _TS_(FmhaPipeline::kK0) + "x" +
_TS_(FmhaPipeline::kTileSizeDv) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
_TS_(FmhaPipeline::kN1) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn)
+ (!kApplyRoPE ? _SS_("") : (_SS_("_") + RotaryEmbeddingEnumToStr<FmhaPipeline::RotaryEnum>::name));
#undef _SS_
@@ -439,7 +439,7 @@ struct FmhaFwdAppendKVKernel
return pad_tensor_view(
v_dram_transposed,
make_tuple(number<FmhaPipeline::kTileSizeDv>{}, number<FmhaPipeline::kN0>{}),
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
else
@@ -453,7 +453,7 @@ struct FmhaFwdAppendKVKernel
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kTileSizeDv>{}, number<FmhaPipeline::kN0>{}),
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
}();
@@ -476,7 +476,7 @@ struct FmhaFwdAppendKVKernel
return pad_tensor_view(
vnew_dram_transposed,
make_tuple(number<FmhaPipeline::kTileSizeDv>{}, number<FmhaPipeline::kN0>{}),
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
else
@@ -490,7 +490,7 @@ struct FmhaFwdAppendKVKernel
return pad_tensor_view(
vnew_dram_naive,
make_tuple(number<FmhaPipeline::kTileSizeDv>{}, number<FmhaPipeline::kN0>{}),
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
}();
@@ -616,15 +616,15 @@ struct FmhaFwdAppendKVKernel
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
{i_n0, 0});
auto v_dram_window = make_tile_window(
v_dram,
make_tuple(number<FmhaPipeline::kTileSizeDv>{}, number<FmhaPipeline::kN0>{}),
{0, kargs.seqlen_k + i_n0});
auto v_dram_window =
make_tile_window(v_dram,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
{0, kargs.seqlen_k + i_n0});
auto vnew_dram_window = make_tile_window(
vnew_dram,
make_tuple(number<FmhaPipeline::kTileSizeDv>{}, number<FmhaPipeline::kN0>{}),
{0, i_n0});
auto vnew_dram_window =
make_tile_window(vnew_dram,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
{0, i_n0});
if constexpr(kApplyRoPE)
{

View File

@@ -7,15 +7,15 @@
namespace ck_tile {
template <index_t kM0_, index_t kN0_, index_t kK0_, index_t kTileSizeDv_>
template <index_t kM0_, index_t kN0_, index_t kK0_, index_t kN1_>
struct FmhaFwdAppendKVTilePartitioner
{
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN0 = kN0_;
static constexpr ck_tile::index_t kK0 = kK0_;
static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_;
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN0 = kN0_;
static constexpr ck_tile::index_t kK0 = kK0_;
static constexpr ck_tile::index_t kN1 = kN1_;
static_assert(kK0 == kTileSizeDv);
static_assert(kK0 == kN1);
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,

View File

@@ -22,10 +22,10 @@ struct BlockFmhaFwdAppendKVPipeline
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = Problem::kM0;
static constexpr index_t kN0 = Problem::kN0;
static constexpr index_t kK0 = Problem::kK0;
static constexpr index_t kTileSizeDv = Problem::kTileSizeDv;
static constexpr index_t kM0 = Problem::kM0;
static constexpr index_t kN0 = Problem::kN0;
static constexpr index_t kK0 = Problem::kK0;
static constexpr index_t kN1 = Problem::kN1;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
@@ -94,8 +94,8 @@ struct BlockFmhaFwdAppendKVPipeline
KDramBlockWindow& k_dram_block_window, // N0*K0 tile
const KnewDramBlockWindow& knew_dram_block_window, // N0*K0 tile
const KnewElementFunction& knew_element_func,
VDramBlockWindow& v_dram_block_window, // N1*K1 tile
const VnewDramBlockWindow& vnew_dram_block_window, // N1*K1 tile
VDramBlockWindow& v_dram_block_window, // N1*N0 tile
const VnewDramBlockWindow& vnew_dram_block_window, // N1*N0 tile
const VnewElementFunction& vnew_element_func,
const QRotaryCosDramBlockWindow q_rotary_cos_dram_block_window,
const QRotarySinDramBlockWindow q_rotary_sin_dram_block_window,

View File

@@ -34,7 +34,7 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::kN0;
constexpr index_t kKPerBlock = Problem::kTileSizeDv;
constexpr index_t kKPerBlock = Problem::kN1;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
// TODO: not correct!
@@ -188,7 +188,7 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::kTileSizeDv;
constexpr index_t kNPerBlock = Problem::kN1;
constexpr index_t kKPerBlock = Problem::kN0;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)

View File

@@ -13,7 +13,7 @@ template <typename QDataType_,
index_t kM0_,
index_t kN0_,
index_t kK0_,
index_t kTileSizeDv_,
index_t kN1_,
bool IsVLayoutRowMajor_,
bool kIsGroupMode_,
typename Traits_>
@@ -27,10 +27,10 @@ struct BlockFmhaFwdAppendKVPipelineProblem
static constexpr index_t kBlockSize = 256;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr index_t kM0 = kM0_;
static constexpr index_t kN0 = kN0_;
static constexpr index_t kK0 = kK0_;
static constexpr index_t kTileSizeDv = kTileSizeDv_;
static constexpr index_t kM0 = kM0_;
static constexpr index_t kN0 = kN0_;
static constexpr index_t kK0 = kK0_;
static constexpr index_t kN1 = kN1_;
using VLayout = std::conditional_t<IsVLayoutRowMajor_,
ck_tile::tensor_layout::gemm::RowMajor,