mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 19:09:59 +00:00
Change gemm0 to iterate along kN0 so that BlockGemm can overlap with masking and SiLU
This commit is contained in:
@@ -579,20 +579,10 @@ struct HstuAttentionFwdKernel
|
||||
make_tuple(kargs.seq_stride_q, 1),
|
||||
number<HstuAttentionPipeline::kAlignmentQ>{},
|
||||
number<1>{});
|
||||
if constexpr(HstuAttentionPipeline::kQLoadOnce)
|
||||
{
|
||||
return pad_tensor_view(q_dram_naive,
|
||||
make_tuple(number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kSubQKHeaddim>{}),
|
||||
sequence<false, kPadHeadDimQK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(q_dram_naive,
|
||||
make_tuple(number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kK0>{}),
|
||||
sequence<false, kPadHeadDimQK>{});
|
||||
}
|
||||
return pad_tensor_view(q_dram_naive,
|
||||
make_tuple(number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kSubQKHeaddim>{}),
|
||||
sequence<false, kPadHeadDimQK>{});
|
||||
}();
|
||||
const auto k_dram = [&]() {
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
@@ -604,7 +594,7 @@ struct HstuAttentionFwdKernel
|
||||
|
||||
return pad_tensor_view(k_dram_naive,
|
||||
make_tuple(number<HstuAttentionPipeline::kN0>{},
|
||||
number<HstuAttentionPipeline::kK0>{}),
|
||||
number<HstuAttentionPipeline::kQKHeaddim>{}),
|
||||
sequence<false, kPadHeadDimQK>{});
|
||||
}();
|
||||
const auto v_dram = [&]() {
|
||||
@@ -645,22 +635,19 @@ struct HstuAttentionFwdKernel
|
||||
}
|
||||
}();
|
||||
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram,
|
||||
[&]() {
|
||||
if constexpr(HstuAttentionPipeline::kQLoadOnce)
|
||||
return make_tuple(number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kSubQKHeaddim>{});
|
||||
else
|
||||
return make_tuple(number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kK0>{});
|
||||
}(),
|
||||
{i_m0, 0});
|
||||
auto q_dram_window =
|
||||
make_tile_window(q_dram,
|
||||
[&]() {
|
||||
return make_tuple(number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kQKHeaddim>{});
|
||||
}(),
|
||||
{i_m0, 0});
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram,
|
||||
make_tuple(number<HstuAttentionPipeline::kN0>{}, number<HstuAttentionPipeline::kK0>{}),
|
||||
{0, 0});
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram,
|
||||
make_tuple(number<HstuAttentionPipeline::kN0>{},
|
||||
number<HstuAttentionPipeline::kQKHeaddim>{}),
|
||||
{0, 0});
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram,
|
||||
|
||||
@@ -148,7 +148,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
@@ -157,9 +157,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
static_assert(2 <= k0_loops);
|
||||
static_assert(2 <= k1_loops);
|
||||
|
||||
constexpr auto NumKLdsBuffers = Policy::template GetNumKLdsBuffers<Problem>();
|
||||
@@ -178,19 +176,14 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
auto k_dram_block_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_k_start, 0});
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window.get_bottom_tensor_view(),
|
||||
k_dram_block_window.get_window_lengths(),
|
||||
k_dram_block_window.get_window_origin(),
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kK1>{}, number<kQKHeaddim>{}),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
auto k_tile = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
|
||||
auto q_tile = load_tile(q_dram_window);
|
||||
|
||||
@@ -204,13 +197,14 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
using k_lds_window_type =
|
||||
decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{}));
|
||||
decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
|
||||
|
||||
statically_indexed_array<k_lds_window_type, NumKLdsBuffers> k_lds_windows;
|
||||
|
||||
static_for<0, NumKLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
k_lds_windows[i_buf] = get_slice_tile(
|
||||
k_lds_window, sequence<i_buf * kN0, 0>{}, sequence<(i_buf + 1) * kN0, kK0>{});
|
||||
k_lds_windows[i_buf] = get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
|
||||
});
|
||||
|
||||
auto v_dram_window =
|
||||
@@ -243,8 +237,11 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
|
||||
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
auto s_acc = SaccBlockTileType{};
|
||||
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
using PcompBlockTileType = decltype(cast_tile<CompDataType>(SaccBlockTileType{}));
|
||||
|
||||
statically_indexed_array<SaccBlockTileType, k1_loops> sacc_tiles;
|
||||
statically_indexed_array<PcompBlockTileType, k1_loops> pcomp_tiles;
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_silu = [](CompDataType& x) {
|
||||
@@ -274,7 +271,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
make_tuple(number<kM0>{}, number<kK1>{}),
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
@@ -303,105 +300,98 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
q_tile = tile_elementwise_in(q_element_func, q_tile);
|
||||
|
||||
auto seqlen_k_curr = seqlen_k_start;
|
||||
|
||||
index_t i_loop = 0;
|
||||
|
||||
do
|
||||
{
|
||||
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
|
||||
store_tile(k_lds_windows[number<i_k0 % NumKLdsBuffers>{}],
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(k_lds_windows[number<i_k1 % NumKLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tile));
|
||||
if constexpr(i_k0 == 0)
|
||||
clear_tile(s_acc);
|
||||
|
||||
k_tile = load_tile(k_dram_window);
|
||||
if constexpr(i_k0 < k0_loops - 2)
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
clear_tile(sacc_tiles[i_k1]);
|
||||
|
||||
if constexpr(i_k1 < k1_loops - 1)
|
||||
{
|
||||
k_tile = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, NumPrefetchV, 1>{}([&](auto i_buf) {
|
||||
v_tiles[i_buf] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
});
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
// execute current unroll of gemm_0
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(
|
||||
q_tile, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
|
||||
k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
|
||||
gemm_0(sacc_tiles[i_k1], q_tile, k_lds_windows[number<i_k1 % NumKLdsBuffers>{}]);
|
||||
|
||||
sacc_tiles[i_k1] = tile_elementwise_in(s_acc_element_func, sacc_tiles[i_k1]);
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, siLU
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
|
||||
tile_elementwise_inout(
|
||||
[&scale_s, &bias_element_func](auto& x, const auto& y) {
|
||||
x = x * scale_s + type_convert<GemmAccDataType>(bias_element_func(y));
|
||||
},
|
||||
sacc_tiles[i_k1],
|
||||
bias_tile);
|
||||
|
||||
move_tile_window(bias_dram_window, {0, kK1});
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; },
|
||||
sacc_tiles[i_k1]);
|
||||
}
|
||||
|
||||
if constexpr(HstuMask::IsMasking)
|
||||
{
|
||||
set_tile_if(
|
||||
sacc_tiles[i_k1], type_convert<GemmAccDataType>(0), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
return !mask.IsTokenPairInsideMask(row, col);
|
||||
});
|
||||
}
|
||||
else if constexpr(kPadSeqLenK)
|
||||
{
|
||||
set_tile_if(
|
||||
sacc_tiles[i_k1], type_convert<GemmAccDataType>(0), [&](auto tile_idx) {
|
||||
if(q_origin.at(number<0>{}) + kM0 <= mask.max_uih_len &&
|
||||
i_loop < num_loops - 1)
|
||||
return false;
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
return !mask.IsTokenPairInsideMask(row, col);
|
||||
});
|
||||
}
|
||||
|
||||
pcomp_tiles[i_k1] = cast_tile<CompDataType>(sacc_tiles[i_k1]);
|
||||
|
||||
tile_elementwise_inout(f_silu, pcomp_tiles[i_k1]);
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto randval_lds_ptr = reinterpret_cast<char*>(smem_ptr) +
|
||||
Policy::template GetSmemSizeKV<Problem>();
|
||||
|
||||
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
|
||||
randval_lds_ptr, seqlen_k_curr, pcomp_tiles[i_k1], null_randval_window);
|
||||
}
|
||||
|
||||
seqlen_k_curr += kK1;
|
||||
});
|
||||
|
||||
store_tile(k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tile));
|
||||
|
||||
// prefetch first v_tile
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kM0, k0_loops * kK0>{}),
|
||||
k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
|
||||
static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) {
|
||||
v_tiles[i_buf] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
});
|
||||
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, siLU
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
tile_elementwise_inout(
|
||||
[&scale_s, &bias_element_func](auto& x, const auto& y) {
|
||||
x = x * scale_s + type_convert<GemmAccDataType>(bias_element_func(y));
|
||||
},
|
||||
s_acc,
|
||||
bias_tile);
|
||||
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
}
|
||||
|
||||
if constexpr(HstuMask::IsMasking)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
set_tile_if(s_acc, type_convert<GemmAccDataType>(0), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !mask.IsTokenPairInsideMask(row, col);
|
||||
});
|
||||
}
|
||||
else if constexpr(kPadSeqLenK)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
set_tile_if(s_acc, type_convert<GemmAccDataType>(0), [&](auto tile_idx) {
|
||||
if(q_origin.at(number<0>{}) + kM0 <= mask.max_uih_len && i_loop < num_loops - 1)
|
||||
return false;
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !mask.IsTokenPairInsideMask(row, col);
|
||||
});
|
||||
}
|
||||
|
||||
auto s = cast_tile<CompDataType>(s_acc);
|
||||
|
||||
tile_elementwise_inout(f_silu, s);
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto randval_lds_ptr =
|
||||
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
|
||||
|
||||
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
|
||||
randval_lds_ptr, seqlen_k_start + i_loop * kN0, s, null_randval_window);
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x7f);
|
||||
// load one k_tile for next iteration
|
||||
k_tile = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -426,59 +416,50 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
store_tile(v_lds_windows[I0],
|
||||
tile_elementwise_in(v_element_func, v_tiles[I0])); // store the prefetch
|
||||
}
|
||||
};
|
||||
|
||||
const auto p = [&]() {
|
||||
if constexpr(std::is_same_v<PDataType, fp16_t>)
|
||||
return impl::cast_tile_pk_fp16_fp32<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, s));
|
||||
else
|
||||
return cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, s));
|
||||
}();
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
const auto p = [&]() {
|
||||
if constexpr(std::is_same_v<PDataType, fp16_t>)
|
||||
return impl::cast_tile_pk_fp16_fp32<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1]));
|
||||
else
|
||||
return cast_tile<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1]));
|
||||
}();
|
||||
|
||||
move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0});
|
||||
k_tile = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
if constexpr(i_k1 < k1_loops - NumPrefetchV)
|
||||
{
|
||||
v_tiles[number<i_k1 % NumPrefetchV>{}] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
gemm_1(
|
||||
o_acc,
|
||||
get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_windows[number<i_k1 % NumVLdsBuffers>{}]);
|
||||
gemm_1(o_acc, p, v_lds_windows[number<i_k1 % NumVLdsBuffers>{}]);
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(i_k1 < k1_loops - 1)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]);
|
||||
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp));
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func,
|
||||
v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]));
|
||||
}
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]);
|
||||
|
||||
if constexpr(i_k1 < k1_loops - NumPrefetchV)
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func,
|
||||
v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
|
||||
tile_elementwise_in(
|
||||
v_element_func,
|
||||
v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}])); // store the
|
||||
// prefetch
|
||||
}
|
||||
};
|
||||
});
|
||||
|
||||
// move K tile windows
|
||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
|
||||
v_lds_windows[number<(k1_loops - 1) % NumVLdsBuffers>{}]);
|
||||
|
||||
// the over-lap only occurs when k1_loops is 3/5/7, NumVLdsBuffers is 2
|
||||
if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer<Problem>())
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
@@ -61,8 +61,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentK<Problem>();
|
||||
|
||||
@@ -100,8 +100,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t NumKLdsBuffers = GetNumKLdsBuffers<Problem>();
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentK<Problem>();
|
||||
|
||||
@@ -147,8 +147,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType);
|
||||
|
||||
@@ -300,8 +300,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
typename Problem::GemmAccDataType,
|
||||
Problem::kNumGemm0Warps * get_warp_size(),
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
Problem::BlockFmhaShape::kK1,
|
||||
Problem::BlockFmhaShape::kQKHeaddim>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ struct HstuAttentionFwdBlockTile<64>
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<128>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 64, 32, 128, 32, 128>;
|
||||
using type = ck_tile::sequence<128, 128, 32, 128, 32, 128>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user