mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
Separate kN0Sub from kK0 to be used for flexible tile tuning for whole_k_prefetch pipeline
This commit is contained in:
@@ -1349,7 +1349,7 @@ struct FmhaFwdKernel
|
||||
if constexpr(detail::is_n0loop_pipeline_v<FmhaPipeline>)
|
||||
{
|
||||
return pad_tensor_view(k_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kK1>{},
|
||||
make_tuple(number<FmhaPipeline::kN0Sub>{},
|
||||
number<FmhaPipeline::kSubQKHeaddim>{}),
|
||||
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
|
||||
}
|
||||
|
||||
@@ -42,6 +42,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kN0Sub = BlockFmhaShape::kN0Sub;
|
||||
static constexpr index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
@@ -177,14 +178,17 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
constexpr index_t n0_loops = kN0 / kN0Sub;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
// usually kN0 is 128, kK1 is 32/16
|
||||
// usually kN0 is 128, kN0Sub/kK1 is 32/16
|
||||
static_assert(n0_loops >= 2, "n0_loops >= 2 required to use this pipeline");
|
||||
static_assert(k1_loops >= 2, "k1_loops >= 2 required to use this pipeline");
|
||||
|
||||
constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
|
||||
|
||||
constexpr index_t NumPrefetchV = Policy::template GetNumPrefetchV<Problem>();
|
||||
static_assert(n0_loops >= NumPrefetchV, "Check failed!");
|
||||
static_assert(k1_loops >= NumPrefetchV, "Check failed!");
|
||||
|
||||
constexpr bool kPreloadWholeNextIterationK =
|
||||
@@ -196,7 +200,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
|
||||
// SaccBlockTile size is [kM0, kK1]
|
||||
// PcompBlockTile size is [kM0, kN0]
|
||||
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kK1>());
|
||||
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0Sub>());
|
||||
using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0>());
|
||||
using PcompBlockTileType = decltype(cast_tile<CompDataType>(CombineSaccBlockTileType{}));
|
||||
|
||||
@@ -227,7 +231,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kK1>{}, number<kSubQKHeaddim>{}),
|
||||
make_tuple(number<kN0Sub>{}, number<kSubQKHeaddim>{}),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
@@ -236,13 +240,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
// only prefetch two k tiles to save vgprs consumption
|
||||
auto k_tiles = [&]() {
|
||||
if constexpr(kPreloadWholeNextIterationK)
|
||||
return statically_indexed_array<k_tile_type, k1_loops>{};
|
||||
return statically_indexed_array<k_tile_type, n0_loops>{};
|
||||
else
|
||||
return statically_indexed_array<k_tile_type, 1>{};
|
||||
}();
|
||||
|
||||
k_tiles[I0] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -258,11 +262,11 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
using k_lds_write_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kSubQKHeaddim>{}));
|
||||
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
using k_lds_read_window_type =
|
||||
decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
|
||||
using k_lds_read_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
|
||||
|
||||
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
|
||||
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
|
||||
@@ -270,11 +274,12 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
k_lds_write_windows[i_buf] =
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{});
|
||||
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{});
|
||||
k_lds_read_windows[i_buf] =
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
|
||||
});
|
||||
|
||||
// V tile in LDS
|
||||
@@ -371,75 +376,75 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
{
|
||||
if(seqlen_k_curr < seqlen_k_end - kN0) // not the last iteration
|
||||
{
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]),
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_k1 < k1_loops - 1)
|
||||
if constexpr(i_n0 < n0_loops - 1)
|
||||
{
|
||||
k_tiles[number<i_k1 + 1>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
k_tiles[number<i_n0 + 1>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
if constexpr(i_k1 < NumPrefetchV)
|
||||
if constexpr(i_n0 < NumPrefetchV)
|
||||
{
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
v_tiles[i_n0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
};
|
||||
|
||||
if constexpr(i_k1 == k1_loops - 1)
|
||||
if constexpr(i_n0 == n0_loops - 1)
|
||||
{
|
||||
// prefetch all k_tiles for next iteration
|
||||
static_for<0, k1_loops, 1>{}([&](auto ii_k1) {
|
||||
k_tiles[number<ii_k1>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
static_for<0, n0_loops, 1>{}([&](auto ii_n0) {
|
||||
k_tiles[number<ii_n0>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
});
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(sacc_tile,
|
||||
q_tile,
|
||||
k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
|
||||
k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_k1 * kK1>{},
|
||||
sequence<kM0, (i_k1 + 1) * kK1>{});
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
}
|
||||
else // the iteration is also the last iteration
|
||||
{
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]),
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_k1 < k1_loops - 1)
|
||||
if constexpr(i_n0 < n0_loops - 1)
|
||||
{
|
||||
k_tiles[number<i_k1 + 1>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
k_tiles[number<i_n0 + 1>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
if constexpr(i_k1 < NumPrefetchV)
|
||||
if constexpr(i_n0 < NumPrefetchV)
|
||||
{
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
v_tiles[i_n0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(sacc_tile,
|
||||
q_tile,
|
||||
k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
|
||||
k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_k1 * kK1>{},
|
||||
sequence<kM0, (i_k1 + 1) * kK1>{});
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
};
|
||||
}
|
||||
@@ -447,87 +452,87 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
{
|
||||
if(seqlen_k_curr < seqlen_k_end - kN0) // intermediate iteration
|
||||
{
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]),
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_k1 < NumPrefetchV)
|
||||
if constexpr(i_n0 < NumPrefetchV)
|
||||
{
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
v_tiles[i_n0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
|
||||
// prefetch k_tile for next iteration
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
k_tiles[i_n0] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
// prefetch other k_tiles for next iteration
|
||||
if constexpr(i_k1 >= NumPrefetchV)
|
||||
if constexpr(i_n0 >= NumPrefetchV)
|
||||
{
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
k_tiles[i_n0] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(sacc_tile,
|
||||
q_tile,
|
||||
k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
|
||||
k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_k1 * kK1>{},
|
||||
sequence<kM0, (i_k1 + 1) * kK1>{});
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
}
|
||||
else // last iteration
|
||||
{
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]),
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_k1 < NumPrefetchV)
|
||||
if constexpr(i_n0 < NumPrefetchV)
|
||||
{
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
v_tiles[i_n0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(sacc_tile,
|
||||
q_tile,
|
||||
k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
|
||||
k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_k1 * kK1>{},
|
||||
sequence<kM0, (i_k1 + 1) * kK1>{});
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
};
|
||||
}
|
||||
}
|
||||
else // only preload one unroll of K for next iteration
|
||||
{
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[I0]),
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
if constexpr(i_k1 < k1_loops - 1)
|
||||
if constexpr(i_n0 < n0_loops - 1)
|
||||
{
|
||||
k_tiles[I0] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
if constexpr(i_k1 < NumPrefetchV)
|
||||
if constexpr(i_n0 < NumPrefetchV)
|
||||
{
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
v_tiles[i_n0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
};
|
||||
|
||||
@@ -535,14 +540,14 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
|
||||
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_k1 * kK1>{},
|
||||
sequence<kM0, (i_k1 + 1) * kK1>{});
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -677,7 +682,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
|
||||
// check whether first V-LdsBufer overlap with last K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
|
||||
if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
@@ -696,7 +701,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
if(seqlen_k_curr < seqlen_k_end)
|
||||
{
|
||||
k_tiles[I0] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
move_tile_window(k_dram_window, {kN0, 0});
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -117,7 +117,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0Sub;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
|
||||
@@ -170,7 +170,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0Sub;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentK<Problem>();
|
||||
@@ -213,7 +213,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers<Problem>();
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0Sub;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentK<Problem>();
|
||||
@@ -320,7 +320,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0Sub;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
|
||||
@@ -455,7 +455,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
typename Problem::SaccDataType,
|
||||
Problem::kNumGemm0Warps * get_warp_size(),
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kK1,
|
||||
Problem::BlockFmhaShape::kN0Sub,
|
||||
Problem::BlockFmhaShape::kQKHeaddim>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
@@ -47,23 +47,24 @@ struct TileFmhaShape
|
||||
|
||||
static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
|
||||
|
||||
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
|
||||
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
|
||||
static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
|
||||
static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
|
||||
static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
|
||||
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
|
||||
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
|
||||
static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
|
||||
static constexpr index_t kN0Sub = BlockTile::at(number<2>{}); // tile size for dividing kN0
|
||||
static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
|
||||
static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
|
||||
static constexpr index_t kQKHeaddim =
|
||||
BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at
|
||||
// once (or repeately load Q as a whole tile)
|
||||
static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
|
||||
static_assert(kQKHeaddim % kK0 == 0 || kN0 % kN0Sub == 0, "Check failed!");
|
||||
|
||||
static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length<kQKHeaddim>();
|
||||
|
||||
// v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
|
||||
static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
|
||||
using VLayout = std::conditional_t<IsVLayoutRowMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor>;
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor>;
|
||||
};
|
||||
|
||||
template <typename BlockTile_, // sequence<...
|
||||
|
||||
Reference in New Issue
Block a user