Add implementation of fwd splitkv on no_softmax path

This commit is contained in:
Qianfeng Zhang
2026-04-15 07:14:40 +00:00
parent a95f64601d
commit 9279af33f1
18 changed files with 2530 additions and 22 deletions

View File

@@ -22,6 +22,8 @@
#include "hstu_attention_fwd_kernel.hpp"
#include "hstu_attention_epilogue.hpp"
#include "hstu_attention_batched_forward_splitkv_dispatch.hpp"
template <typename InOutDataType,
bool kUseCausal,
bool kUseSoftmax,
@@ -190,7 +192,7 @@ template <typename InOutDataType,
void run_batched_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionNoGroupFwdParams& param,
hipStream_t stream)
{
if(get_hstu_attention_fwd_mtile(param.num_batch, param.num_head, param.max_seqlen) == 128)
if(get_hstu_attention_fwd_mtile(param.num_batch, param.num_head, param.seqlen_q) == 128)
batched_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
kUseCausal,
kUseSoftmax,
@@ -199,11 +201,37 @@ void run_batched_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionNoGro
MaxK,
128>::Run(param, stream);
else
batched_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK,
64>::Run(param, stream);
{
const bool disable_fwd_splitkv = []() {
const char* env_p = std::getenv("HSTU_DISABLE_SPLITKV");
if(env_p == nullptr)
return false;
return static_cast<bool>(atoi(env_p));
}();
// ToDo: enable splitkv when kUseSoftmax is true
if(!disable_fwd_splitkv && !kUseSoftmax &&
shall_use_splitkv(param.num_batch, param.num_head, param.seqlen_q))
{
if constexpr(!kUseSoftmax)
{
batched_forward_splitkv_causal_softmax_bias_dropout_dispatch<InOutDataType,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK,
64>::Run(param,
stream);
};
}
else
batched_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK,
64>::Run(param, stream);
};
};

View File

@@ -0,0 +1,274 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core/numeric/integer.hpp>
#include <ck_tile/host/kernel_launch.hpp>
#include <ck_tile/host/stream_config.hpp>
#include <ck_tile/ops/epilogue.hpp>
#include <ck_tile/host/hip_check_error.hpp>
#include "hstu_attention_bool_switch.hpp"
#include "hstu_attention_fwd_type_config.hpp"
#include "hstu_attention_fwd_setting.hpp"
#include "hstu_attention_fwd_splitkv_combine_setting.hpp"
#include "hstu_attention_params.hpp"
#include "hstu_attention_hdim_switch.hpp"
#include "hstu_attention_pipeline_problem.hpp"
#include "hstu_attention_traits.hpp"
#include "hstu_attention_with_softmax_fwd_pipeline.hpp"
#include "hstu_attention_no_softmax_fwd_pipeline.hpp"
#include "hstu_attention_with_softmax_fwd_trload_pipeline.hpp"
#include "hstu_attention_no_softmax_fwd_trload_pipeline.hpp"
#include "hstu_attention_no_softmax_fwd_splitkv_combine_pipeline.hpp"
#include "hstu_attention_fwd_splitkv_kernel.hpp"
#include "hstu_attention_fwd_splitkv_combine_kernel.hpp"
template <typename InOutDataType,
bool kUseCausal,
bool kUseSoftmax,
bool kHasBias,
bool kHasDropout,
ck_tile::index_t MaxK,
ck_tile::index_t MTile>
struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
{
static_assert(kUseSoftmax == false, "Softmax support is not enabled yet!");
static_assert(MTile == 64, "MTile must be 64 to get to fwd splitkv path!");
using HstuAttentionFwdTileSetting =
typename std::conditional_t<kUseSoftmax,
HstuAttentionWithSoftmaxFwdTileSetting<MaxK, MTile>,
HstuAttentionNoSoftmaxFwdTileSetting<MaxK, MTile>>::Type;
using HstuAttentionCombineTileSetting = HstuAttentionFwdSplitKVCombineTileSetting<MaxK>::Type;
#ifdef BUILD_HSTU_FOR_GFX95_ONLY
static constexpr bool kUseTrLoad = true;
#else
static constexpr bool kUseTrLoad = false;
#endif
template <bool kIsCrossAttention>
using HstuFwdPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
InOutDataType,
typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType,
typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType,
typename HstuAttentionFwdTypeConfig<InOutDataType>::BiasDataType,
kIsCrossAttention,
false, // kUseGroup
false, // kIsJagged
kHasBias,
kHasDropout,
kUseCausal,
kUseSoftmax,
kUseTrLoad,
HstuAttentionFwdTileSetting>;
using OaccDataType = HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType;
using ODataType = HstuAttentionFwdTypeConfig<InOutDataType>::ODataType;
static void Run(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
{
constexpr ck_tile::index_t occupancy = -1;
{
const bool pad_seqlen_k = !(param.seqlen_kv % HstuAttentionFwdTileSetting::kN0 == 0);
const bool pad_headdim_qk =
!(param.hdim_qk % HstuAttentionFwdTileSetting::kQKHeaddim == 0);
const bool pad_headdim_v = !(param.hdim_v % HstuAttentionFwdTileSetting::kN1 == 0);
// no need to check seqlen_q since it is not used as fastest dim,
// buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
constexpr bool kPadSeqLenQ = false;
BOOL_SWITCH_3(
pad_seqlen_k,
kPadSeqLenK,
pad_headdim_qk,
kPadHeadDimQK,
pad_headdim_v,
kPadHeadDimV,
[&] {
using HstuTraits = ck_tile::HstuAttentionFwdTraits<kPadSeqLenQ,
kPadSeqLenK,
kPadHeadDimQK,
kPadHeadDimV,
occupancy>;
using HstuEpilogue =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
OaccDataType,
OaccDataType, // keep output as OaccDataType
kPadSeqLenQ,
kPadHeadDimV,
false>>;
BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] {
using HstuPipelineProblem = HstuFwdPipelineProblemTemp<kIsCrossAttention>;
if constexpr(!kUseTrLoad)
{
using HstuPipeline = std::conditional_t<
kUseSoftmax,
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVS<
HstuPipelineProblem,
HstuTraits>,
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVS<
HstuPipelineProblem,
HstuTraits>>;
using HstuKernel =
ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;
RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
}
else
{
using HstuPipeline = std::conditional_t<
kUseSoftmax,
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad<
HstuPipelineProblem,
HstuTraits>,
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad<
HstuPipelineProblem,
HstuTraits>>;
using HstuKernel =
ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;
RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
};
});
});
};
{
using HstuCombinePipelineProblem =
ck_tile::HstuAttentionFwdSplitKVCombinePipelineProblem<
OaccDataType,
ODataType,
false /* kIsJagged */,
kUseSoftmax,
HstuAttentionCombineTileSetting>;
const bool pad_headdim_o =
!(param.hdim_v % HstuAttentionCombineTileSetting::kOHeaddim == 0);
// no need to check seqlen_q since it is not used as fastest dim,
// buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
constexpr bool kPadSeqLenQ = false;
BOOL_SWITCH(pad_headdim_o, kPadHeadDimO, [&] {
using HstuTraits = ck_tile::
HstuAttentionFwdSplitKVCombineTraits<kPadSeqLenQ, kPadHeadDimO, occupancy>;
using HstuEpilogue =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<OaccDataType,
ODataType,
kPadSeqLenQ,
kPadHeadDimO,
false>>;
using HstuPipeline = ck_tile::HstuAttentionNoSoftmaxFwdSplitKVCombinePipeline<
HstuCombinePipelineProblem,
HstuTraits>;
using HstuKernel =
ck_tile::HstuAttentionFwdSplitKVCombineKernel<HstuPipeline, HstuEpilogue>;
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, stream);
});
};
};
template <typename HstuKernel>
static void RunWithFwdSplitKVKernel(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
{
param.num_splits =
get_suggested_num_splits(param.num_batch, param.num_head, param.seqlen_q);
// assume the workspace for o_acc is in compact shape of [num_batch, seqlen_q, num_head,
// num_splits, hdim]
size_t workspace_bytes = static_cast<size_t>(param.num_batch) * param.seqlen_q *
param.num_head * param.num_splits * param.hdim_v *
sizeof(OaccDataType);
HIP_CHECK_ERROR(hipMallocAsync(&param.o_acc_ptr, workspace_bytes, stream));
const auto kargs = [&] {
return HstuKernel::MakeKargs(param.q_ptr,
param.k_ptr,
param.v_ptr,
param.bias_ptr,
param.o_acc_ptr,
param.num_splits,
param.seqlen_q,
param.is_cross_attention ? param.seqlen_kv
: param.seqlen_q,
param.hdim_qk,
param.hdim_v,
param.num_head,
param.scale_s,
param.attn_scale,
param.seq_stride_q,
param.seq_stride_k,
param.seq_stride_v,
param.seq_stride_bias,
param.nhead_stride_q,
param.nhead_stride_k,
param.nhead_stride_v,
param.nhead_stride_bias,
param.batch_stride_q,
param.batch_stride_k,
param.batch_stride_v,
param.batch_stride_bias,
param.num_targets_ptr,
param.contextual_seqlen,
param.window_size,
param.min_full_attn_seqlen,
param.p_drop,
param.philox_seed,
param.philox_offset);
}();
bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0);
dim3 kGridSize = HstuKernel::GridSize(param.num_batch,
param.num_head,
param.seqlen_q,
param.hdim_v,
param.num_splits,
has_minfull_attn_seqlen);
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
(void)ck_tile::launch_kernel(
ck_tile::stream_config{stream, false},
ck_tile::make_kernel<kBlockPerCu>(HstuKernel{}, kGridSize, kBlockSize, 0, kargs));
};
template <typename HstuKernel>
static void RunWithFwdSplitKVCombineKernel(HstuAttentionNoGroupFwdParams& param,
hipStream_t stream)
{
const auto kargs = [&] {
return HstuKernel::MakeKargs(param.o_acc_ptr,
param.o_ptr,
param.batch_stride_o,
param.seq_stride_o,
param.nhead_stride_o,
param.seqlen_q,
param.num_head,
param.num_splits,
param.hdim_v);
}();
dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.seqlen_q);
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
(void)ck_tile::launch_kernel(
ck_tile::stream_config{stream, false},
ck_tile::make_kernel<kBlockPerCu>(HstuKernel{}, kGridSize, kBlockSize, 0, kargs));
HIP_CHECK_ERROR(hipFreeAsync(param.o_acc_ptr, stream));
};
};

View File

@@ -484,3 +484,36 @@ static int get_hstu_attention_fwd_mtile(int num_batches, int num_heads, int max_
// mtile-64 can be added through tuning/verification
return 64;
};
// Estimate how many waves of work-groups the split-less forward launch would
// occupy relative to the device, assuming an MTile of 64 rows per work-group.
// A ratio well below 1.0 means the grid is too small to fill the GPU.
static float get_estimated_cu_coverage_ratio(int num_batches, int num_heads, int max_seqlen_q)
{
    const int cu_count = get_number_of_cu();
    // ceil-divide the query length into MTile-64 row blocks
    const int m_blocks          = (max_seqlen_q + 63) / 64;
    const int total_work_groups = num_batches * num_heads * m_blocks;
    // assume each CU can run two work-groups, common cases for hdim128
    return static_cast<float>(total_work_groups) / (2.0f * cu_count);
};
// Decide whether the split-KV path is worthwhile: use it when the plain launch
// would leave the device under-occupied (estimated coverage below threshold).
static bool shall_use_splitkv(int num_batches, int num_heads, int max_seqlen_q)
{
    // Please tune the threshold here
    constexpr float kCoverageThreshold = 0.8f;
    return get_estimated_cu_coverage_ratio(num_batches, num_heads, max_seqlen_q) <
           kCoverageThreshold;
};
// Suggest the KV-split count: the smallest value >= 2 that lifts the estimated
// CU coverage above the threshold.
//
// FIX: the original loop was unbounded — for very small workloads the coverage
// ratio can be tiny, producing an absurdly large split count and a
// proportionally huge o_acc workspace allocation in the dispatchers. Cap the
// result so both the combine reduction and the workspace stay bounded.
static int get_suggested_num_splits(int num_batches, int num_heads, int max_seqlen_q)
{
    // Please tune the threshold here
    const float threshold = 1.5f;
    // Please tune the cap here; it bounds the o_acc workspace size as well
    constexpr int max_num_splits = 32;
    // ratio is loop-invariant, so compute it once
    const float coverage =
        get_estimated_cu_coverage_ratio(num_batches, num_heads, max_seqlen_q);
    int num_splits = 2;
    while(num_splits < max_num_splits && coverage * num_splits < threshold)
        num_splits++;
    return num_splits;
};

View File

@@ -0,0 +1,282 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2026, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
#include "hstu_block_masking.hpp"
#ifndef HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
#define HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM 1
#endif
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
namespace ck_tile {
/// Combine (reduction) kernel for the forward split-KV path.
///
/// Reads the per-split partial outputs from the o_acc workspace — assumed to be
/// in compact shape [batch, seqlen_q, num_head, num_splits, hdim_v] — reduces
/// across the num_splits dimension (the reduction itself is implemented by the
/// pipeline), and writes the final output tile through the epilogue into o_ptr.
/// Each work-group handles one (batch, head, M-tile-of-seqlen_q) triple.
template <typename HstuAttentionPipeline_, typename EpiloguePipeline_>
struct HstuAttentionFwdSplitKVCombineKernel
{
    using HstuAttentionPipeline = ck_tile::remove_cvref_t<HstuAttentionPipeline_>;
    using EpiloguePipeline      = ck_tile::remove_cvref_t<EpiloguePipeline_>;

    static constexpr ck_tile::index_t kBlockSize  = HstuAttentionPipeline::kBlockSize;
    static constexpr ck_tile::index_t kBlockPerCu = HstuAttentionPipeline::kBlockPerCu;
    static_assert(kBlockPerCu > 0);

    using OaccDataType =
        ck_tile::remove_cvref_t<typename HstuAttentionPipeline::Problem::OaccDataType>;
    using ODataType = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::Problem::ODataType>;

    static constexpr bool kIsJagged   = HstuAttentionPipeline::Problem::kIsJagged;
    static constexpr bool kPadSeqLenQ = HstuAttentionPipeline::kPadSeqLenQ;
    static constexpr bool kPadHeadDimO = HstuAttentionPipeline::kPadHeadDimO;

    template <ck_tile::index_t I> // to avoid duplicated base class problem, introduce an template
                                  // arg
    struct HstuAttentionCombineEmptyKargs
    {
    };

    // kargs use aggregate initializer, so no constructor will provided
    // use inheritance to minimize karg size
    // user need to use MakeKargs() function to create kargs.
    struct HstuAttentionBatchedCombineBaseKargs
    {
        const void* o_acc_ptr;
        void* o_ptr;
        ck_tile::index_t batch_stride_o;
        ck_tile::index_t seq_stride_o;
        ck_tile::index_t nhead_stride_o;
        ck_tile::index_t seqlen_q;
        ck_tile::index_t num_head;
        ck_tile::index_t num_splits;
        ck_tile::index_t hdim_v;
    };

    struct HstuAttentionJaggedCombineBaseKargs
    {
        const void* o_acc_ptr;
        void* o_ptr;
        ck_tile::index_t seq_stride_o;
        ck_tile::index_t nhead_stride_o;
        const int32_t* seq_q_offsets_ptr;
        ck_tile::index_t num_head;
        ck_tile::index_t num_splits;
        ck_tile::index_t hdim_v;
        // per-batch value, filled in on-device from seq_q_offsets_ptr
        ck_tile::index_t seqlen_q;
    };

    struct HstuAttentionBatchedCombineKargs : HstuAttentionBatchedCombineBaseKargs
    {
    };
    struct HstuAttentionJaggedCombineKargs : HstuAttentionJaggedCombineBaseKargs
    {
    };

    using Kargs = std::
        conditional_t<kIsJagged, HstuAttentionJaggedCombineKargs, HstuAttentionBatchedCombineKargs>;

    /// Batched variant of the kernel-argument factory.
    template <bool Cond = !kIsJagged>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* o_acc_ptr, // workspace for accumulation of o
              void* o_ptr,
              ck_tile::index_t batch_stride_o,
              ck_tile::index_t seq_stride_o,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t num_head,
              ck_tile::index_t num_splits, // number of splitted seqlen_kv
              ck_tile::index_t hdim_v)
    {
        Kargs kargs{o_acc_ptr,
                    o_ptr,
                    batch_stride_o,
                    seq_stride_o,
                    nhead_stride_o,
                    seqlen_q,
                    num_head,
                    num_splits,
                    hdim_v};
        return kargs;
    }

    /// Jagged variant: per-batch sequence starts come from seq_q_offsets_ptr.
    template <bool Cond = kIsJagged>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* o_acc_ptr, // workspace for accumulation of o
              void* o_ptr,
              ck_tile::index_t seq_stride_o,
              ck_tile::index_t nhead_stride_o,
              const void* seq_q_offsets_ptr,
              ck_tile::index_t num_head,
              ck_tile::index_t num_splits, // number of splitted seqlen_kv
              ck_tile::index_t hdim_v)
    {
        Kargs kargs{o_acc_ptr,
                    o_ptr,
                    seq_stride_o,
                    nhead_stride_o,
                    reinterpret_cast<const int32_t*>(seq_q_offsets_ptr),
                    num_head,
                    num_splits,
                    hdim_v,
                    0 /* seqlen_q will be updated later */};
        return kargs;
    }

    /// One work-group per (batch, head, kM-sized tile of seqlen); the grid-dim
    /// ordering is controlled by HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM.
    CK_TILE_HOST static constexpr auto
    GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_)
    {
        ck_tile::index_t num_tile_in_seqlen =
            ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM);
#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
        return dim3(batch_size_, nhead_, num_tile_in_seqlen);
#else
        // FIX: the original line was a syntax error -- the closing parenthesis
        // was placed after the first argument:
        //     return dim3(num_tile_in_seqlen), nhead_, batch_size_);
        // which broke any build with HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM=0.
        return dim3(num_tile_in_seqlen, nhead_, batch_size_);
#endif
    }

    /// Decode (i_tile_m, i_nhead, i_batch) from blockIdx, mirroring GridSize().
    CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
    {
        ignore = kargs;
#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
        const index_t i_batch  = blockIdx.x;
        const index_t i_nhead  = blockIdx.y;
        const index_t i_tile_m = blockIdx.z;
#else
        const index_t i_tile_m = blockIdx.x;
        const index_t i_nhead  = blockIdx.y;
        const index_t i_batch  = blockIdx.z;
#endif
        return ck_tile::make_tuple(i_tile_m, i_nhead, i_batch);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }

    // LDS requirement is the max of the pipeline's and the epilogue's needs
    CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return ck_tile::max(HstuAttentionPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex(kargs);
        long_index_t batch_offset_o_acc = 0;
        long_index_t batch_offset_o     = 0;
        if constexpr(kIsJagged)
        {
            // get starting offset for each batch
            const long_index_t query_start = kargs.seq_q_offsets_ptr[i_batch];
            // assume o_acc is in compact shape of [batch_size, max_seqlen, num_head, num_splits,
            // hdim]
            batch_offset_o_acc = query_start * kargs.num_head * kargs.num_splits * kargs.hdim_v;
            batch_offset_o     = query_start * kargs.seq_stride_o;
            // actual per-batch query length from consecutive offsets
            kargs.seqlen_q =
                kargs.seq_q_offsets_ptr[i_batch + 1] - kargs.seq_q_offsets_ptr[i_batch];
        }
        else
        {
            // assume o_acc is in compact shape of [batch_size, seqlen_q, num_head, num_splits,
            // hdim]
            batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.seqlen_q *
                                 kargs.num_head * kargs.num_splits * kargs.hdim_v;
            batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
        }
        index_t i_m0;
        i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM);
        // the tile is entirely past this batch's rows (jagged case) -- nothing to do
        if(kargs.seqlen_q <= i_m0)
            return;
        // assume o_acc is in compact shape of [batch_size, seqlen, num_head, num_splits, hdim]
        const OaccDataType* o_acc_ptr =
            reinterpret_cast<const OaccDataType*>(kargs.o_acc_ptr) +
            static_cast<long_index_t>(i_nhead) * kargs.num_splits * kargs.hdim_v +
            batch_offset_o_acc;
        ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
                           static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
                           batch_offset_o;
        // Oacc DRAM and Oacc DRAM window
        auto seq_stride_o_acc = kargs.num_head * kargs.num_splits * kargs.hdim_v;
        const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            o_acc_ptr,
            make_tuple(kargs.seqlen_q, kargs.hdim_v),
            make_tuple(seq_stride_o_acc, 1),
            number<HstuAttentionPipeline::kAlignmentO>{},
            number<1>{});
        const auto o_acc_dram =
            pad_tensor_view(o_acc_dram_naive,
                            make_tuple(number<HstuAttentionPipeline::kM>{},
                                       number<HstuAttentionPipeline::kOHeaddim>{}),
                            sequence<false, kPadHeadDimO>{});
        auto o_acc_dram_window =
            make_tile_window(o_acc_dram,
                             make_tuple(number<HstuAttentionPipeline::kM>{},
                                        number<HstuAttentionPipeline::kOHeaddim>{}),
                             {i_m0, 0});
        // pipeline performs the reduction over num_splits
        auto o_acc_tile = [&]() {
            return HstuAttentionPipeline{}(o_acc_dram_window, kargs.hdim_v, kargs.num_splits);
        }();
        // O DRAM and O DRAM window
        auto o_dram = [&]() {
            const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                o_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_v),
                make_tuple(kargs.seq_stride_o, 1),
                number<HstuAttentionPipeline::kAlignmentO>{},
                number<1>{});
            return pad_tensor_view(o_dram_naive,
                                   make_tuple(number<HstuAttentionPipeline::kM>{},
                                              number<HstuAttentionPipeline::kOHeaddim>{}),
                                   sequence<false, kPadHeadDimO>{});
        }();
        auto o_dram_window =
            make_tile_window(o_dram,
                             make_tuple(number<HstuAttentionPipeline::kM>{},
                                        number<HstuAttentionPipeline::kOHeaddim>{}),
                             {i_m0, 0});
        EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,83 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core.hpp>
namespace ck_tile {
/// Policy for the split-KV combine pipeline: builds the tile distribution used
/// to load o_acc tiles from DRAM, and exposes the load/store vector alignments.
struct HstuAttentionFwdSplitKVCombinePipelinePolicy
{
    /// Builds a static tile distribution for a [kM, kOHeaddim] o_acc tile.
    /// Two encodings are used depending on whether kOHeaddim equals
    /// Problem::kSubOHeaddim (power-of-2 head dims) or not (96/160-style dims);
    /// in both, consecutive threads are preferentially assigned along dim-K.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution()
    {
        constexpr index_t kMPerBlock = Problem::kM;
        constexpr index_t kKPerBlock = Problem::kOHeaddim;
        constexpr index_t NumWarps   = Problem::NumWarps;
        // vector width per access along K
        constexpr index_t KVector = GetAlignmentOacc<Problem>();
        // number of K-vectors needed to cover the block's K extent
        constexpr index_t OtherK = kKPerBlock / KVector;
        if constexpr(kKPerBlock == Problem::kSubOHeaddim)
        // for kKPerBlock=32,64,128,256
        {
            // OtherK must be a power of two in this branch
            static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!");
            constexpr index_t KPerThread = KVector;
            // try to assign more consecutive threads on dim-K
            constexpr index_t KThreads = OtherK;
            static_assert(KThreads <= get_warp_size(), "Check failed!");
            // remaining lanes of each warp cover dim-M
            constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
            constexpr index_t MPerThread     = kMPerBlock / (MThreadPerWarp * NumWarps);
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
                                                 sequence<KThreads, KPerThread>>,
                                           tuple<sequence<1>, sequence<1, 2>>,
                                           tuple<sequence<1>, sequence<2, 0>>,
                                           sequence<1, 2>,
                                           sequence<0, 1>>{});
        }
        else // for kKPerBlock=96,160
        {
            // OtherK must NOT be a power of two in this branch
            static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!");
            // ensure KThreads be power-of-2 integer
            // (peel off the odd factor 3 or 5 as per-thread repeats)
            constexpr index_t KRepPerThread  = (OtherK % 3 == 0) ? 3 : 5;
            constexpr index_t KThreads       = OtherK / KRepPerThread;
            constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
            constexpr index_t MPerThread     = kMPerBlock / (MThreadPerWarp * NumWarps);
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
                                                 sequence<KRepPerThread, KThreads, KVector>>,
                                           tuple<sequence<1>, sequence<1, 2>>,
                                           tuple<sequence<1>, sequence<2, 1>>,
                                           sequence<1, 2, 2>,
                                           sequence<0, 0, 2>>{});
        };
    }

    /// Max vector size (in elements) for o_acc DRAM accesses, as declared by
    /// the Problem.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
    {
        return Problem::GetOaccDramTileAccessMaxVectorSize();
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
    {
        // should be same as GetAlignmentOacc() since o_tile will use the same encoding as
        // o_acc_tile
        return GetAlignmentOacc<Problem>();
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,34 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "hstu_attention_fwd_type_config.hpp"
#include "hstu_attention_tile_setting_define.hpp"
/// Tile-setting selector for the split-KV combine kernel, specialized per head
/// dim. Primary template is declared only; unsupported head dims fail to
/// compile. Each specialization maps kOHeaddim to a
/// HstuAttentionFwdSplitKVCombineTileSettingClass<kM, NumWarps, kOHeaddim>
/// (the kM/NumWarps values here are hand-picked per head dim).
template <ck_tile::index_t kOHeaddim>
struct HstuAttentionFwdSplitKVCombineTileSetting;

template <>
struct HstuAttentionFwdSplitKVCombineTileSetting<64>
{
    using Type = ck_tile::HstuAttentionFwdSplitKVCombineTileSettingClass<32, 4, 64>;
};

template <>
struct HstuAttentionFwdSplitKVCombineTileSetting<96>
{
    using Type = ck_tile::HstuAttentionFwdSplitKVCombineTileSettingClass<16, 4, 96>;
};

template <>
struct HstuAttentionFwdSplitKVCombineTileSetting<128>
{
    using Type = ck_tile::HstuAttentionFwdSplitKVCombineTileSettingClass<16, 4, 128>;
};

template <>
struct HstuAttentionFwdSplitKVCombineTileSetting<256>
{
    using Type = ck_tile::HstuAttentionFwdSplitKVCombineTileSettingClass<8, 4, 256>;
};

File diff suppressed because it is too large Load Diff

View File

@@ -22,6 +22,8 @@
#include "hstu_attention_fwd_kernel.hpp"
#include "hstu_attention_epilogue.hpp"
#include "hstu_attention_group_forward_splitkv_dispatch.hpp"
template <typename InOutDataType,
bool kUseCausal,
bool kUseSoftmax,
@@ -185,11 +187,36 @@ void run_group_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionGroupFw
MaxK,
128>::Run(param, stream);
else
group_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK,
64>::Run(param, stream);
{
const bool disable_fwd_splitkv = []() {
const char* env_p = std::getenv("HSTU_DISABLE_SPLITKV");
if(env_p == nullptr)
return false;
return static_cast<bool>(atoi(env_p));
}();
// ToDo: enable splitkv when kUseSoftmax is true
if(!disable_fwd_splitkv && !kUseSoftmax &&
shall_use_splitkv(param.num_batch, param.num_head, param.max_seqlen))
{
if constexpr(!kUseSoftmax)
{
group_forward_splitkv_causal_softmax_bias_dropout_dispatch<InOutDataType,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK,
64>::Run(param, stream);
};
}
else
group_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK,
64>::Run(param, stream);
};
};

View File

@@ -0,0 +1,258 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core/numeric/integer.hpp>
#include <ck_tile/host/kernel_launch.hpp>
#include <ck_tile/host/stream_config.hpp>
#include <ck_tile/ops/epilogue.hpp>
#include <ck_tile/host/hip_check_error.hpp>
#include "hstu_attention_bool_switch.hpp"
#include "hstu_attention_fwd_type_config.hpp"
#include "hstu_attention_fwd_setting.hpp"
#include "hstu_attention_fwd_splitkv_combine_setting.hpp"
#include "hstu_attention_params.hpp"
#include "hstu_attention_hdim_switch.hpp"
#include "hstu_attention_pipeline_problem.hpp"
#include "hstu_attention_traits.hpp"
#include "hstu_attention_with_softmax_fwd_pipeline.hpp"
#include "hstu_attention_no_softmax_fwd_pipeline.hpp"
#include "hstu_attention_with_softmax_fwd_trload_pipeline.hpp"
#include "hstu_attention_no_softmax_fwd_trload_pipeline.hpp"
#include "hstu_attention_no_softmax_fwd_splitkv_combine_pipeline.hpp"
#include "hstu_attention_fwd_splitkv_kernel.hpp"
#include "hstu_attention_fwd_splitkv_combine_kernel.hpp"
/// Dispatcher for the grouped/jagged forward split-KV path.
///
/// Two stages on the same stream:
///   1. the split-KV forward kernel writes per-split partial outputs into a
///      stream-ordered o_acc workspace of compact shape
///      [num_batch, max_seqlen, num_head, num_splits, hdim_v], then
///   2. the combine kernel reduces over num_splits into param.o_ptr and frees
///      the workspace.
///
/// Currently only instantiated on the no-softmax, MTile-64 path (see the
/// static_asserts).
template <typename InOutDataType,
          bool kUseCausal,
          bool kUseSoftmax,
          bool kHasBias,
          bool kHasDropout,
          ck_tile::index_t MaxK,
          ck_tile::index_t MTile>
struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
{
    static_assert(kUseSoftmax == false, "Softmax support is not enabled yet!");
    static_assert(MTile == 64, "MTile must be 64 to get to fwd splitkv path!");

    using HstuAttentionFwdTileSetting =
        typename std::conditional_t<kUseSoftmax,
                                    HstuAttentionWithSoftmaxFwdTileSetting<MaxK, MTile>,
                                    HstuAttentionNoSoftmaxFwdTileSetting<MaxK, MTile>>::Type;
    using HstuAttentionCombineTileSetting =
        typename HstuAttentionFwdSplitKVCombineTileSetting<MaxK>::Type;

#ifdef BUILD_HSTU_FOR_GFX95_ONLY
    static constexpr bool kUseTrLoad = true;
#else
    static constexpr bool kUseTrLoad = false;
#endif

    template <bool kIsCrossAttention>
    using HstuFwdPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
        InOutDataType,
        typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType,
        typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType,
        typename HstuAttentionFwdTypeConfig<InOutDataType>::BiasDataType,
        kIsCrossAttention,
        true, // kUseGroup
        true, // kIsJagged
        kHasBias,
        kHasDropout,
        kUseCausal,
        kUseSoftmax,
        kUseTrLoad,
        HstuAttentionFwdTileSetting>;

    using OaccDataType = HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType;
    using ODataType    = HstuAttentionFwdTypeConfig<InOutDataType>::ODataType;

    /// Entry point: selects padding traits / pipelines and launches both stages.
    static void Run(HstuAttentionGroupFwdParams& param, hipStream_t stream)
    {
        constexpr ck_tile::index_t occupancy = -1;
        // ---- stage 1: split-KV forward kernel --------------------------------
        {
            const bool pad_headdim_qk =
                !(param.hdim_qk % HstuAttentionFwdTileSetting::kQKHeaddim == 0);
            const bool pad_headdim_v = !(param.hdim_v % HstuAttentionFwdTileSetting::kN1 == 0);
            // no need to check seqlen_q since it is not used as fastest dim,
            // buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
            constexpr bool kPadSeqLenQ = false;
            // jagged layout: K sequence lengths vary per group, so always pad
            constexpr bool kPadSeqLenK = true;
            BOOL_SWITCH_2(pad_headdim_qk, kPadHeadDimQK, pad_headdim_v, kPadHeadDimV, [&] {
                using HstuTraits = ck_tile::HstuAttentionFwdTraits<kPadSeqLenQ,
                                                                   kPadSeqLenK,
                                                                   kPadHeadDimQK,
                                                                   kPadHeadDimV,
                                                                   occupancy>;
                using HstuEpilogue = ck_tile::Default2DEpilogue<
                    ck_tile::Default2DEpilogueProblem<OaccDataType,
                                                      OaccDataType, // keep output as OaccDataType
                                                      kPadSeqLenQ,
                                                      kPadHeadDimV,
                                                      false>>;
                BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] {
                    using HstuPipelineProblem = HstuFwdPipelineProblemTemp<kIsCrossAttention>;
                    if constexpr(!kUseTrLoad)
                    {
                        using HstuPipeline = std::conditional_t<
                            kUseSoftmax,
                            ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem,
                                                                               HstuTraits>,
                            ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem,
                                                                             HstuTraits>>;
                        using HstuKernel =
                            ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;
                        RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
                    }
                    else
                    {
                        // transposed-load pipeline variant (gfx95-only builds)
                        using HstuPipeline = std::conditional_t<
                            kUseSoftmax,
                            ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad<
                                HstuPipelineProblem,
                                HstuTraits>,
                            ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad<
                                HstuPipelineProblem,
                                HstuTraits>>;
                        using HstuKernel =
                            ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;
                        RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
                    };
                });
            });
        };
        // ---- stage 2: combine kernel (reduce over num_splits into o_ptr) -----
        {
            using HstuCombinePipelineProblem =
                ck_tile::HstuAttentionFwdSplitKVCombinePipelineProblem<
                    OaccDataType,
                    ODataType,
                    true /* kIsJagged */,
                    kUseSoftmax,
                    HstuAttentionCombineTileSetting>;
            const bool pad_headdim_o =
                !(param.hdim_v % HstuAttentionCombineTileSetting::kOHeaddim == 0);
            // no need to check seqlen_q since it is not used as fastest dim,
            // buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
            constexpr bool kPadSeqLenQ = false;
            BOOL_SWITCH(pad_headdim_o, kPadHeadDimO, [&] {
                using HstuTraits = ck_tile::
                    HstuAttentionFwdSplitKVCombineTraits<kPadSeqLenQ, kPadHeadDimO, occupancy>;
                using HstuEpilogue =
                    ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<OaccDataType,
                                                                                 ODataType,
                                                                                 kPadSeqLenQ,
                                                                                 kPadHeadDimO,
                                                                                 false>>;
                using HstuPipeline = ck_tile::HstuAttentionNoSoftmaxFwdSplitKVCombinePipeline<
                    HstuCombinePipelineProblem,
                    HstuTraits>;
                using HstuKernel =
                    ck_tile::HstuAttentionFwdSplitKVCombineKernel<HstuPipeline, HstuEpilogue>;
                RunWithFwdSplitKVCombineKernel<HstuKernel>(param, stream);
            });
        };
    };

    /// Allocates the stream-ordered o_acc workspace and launches the split-KV
    /// forward kernel. The workspace is released by the combine stage.
    template <typename HstuKernel>
    static void RunWithFwdSplitKVKernel(HstuAttentionGroupFwdParams& param, hipStream_t stream)
    {
        param.num_splits =
            get_suggested_num_splits(param.num_batch, param.num_head, param.max_seqlen);
        // assume the workspace for o_acc is in compact shape of [num_batch, max_seqlen, num_head,
        // num_splits, hdim]
        size_t workspace_bytes = static_cast<size_t>(param.num_batch) * param.max_seqlen *
                                 param.num_head * param.num_splits * param.hdim_v *
                                 sizeof(OaccDataType);
        HIP_CHECK_ERROR(hipMallocAsync(&param.o_acc_ptr, workspace_bytes, stream));
        const auto kargs = [&] {
            return HstuKernel::MakeKargs(param.q_ptr,
                                         param.k_ptr,
                                         param.v_ptr,
                                         param.bias_ptr,
                                         param.o_acc_ptr,
                                         param.num_splits,
                                         // presumably batches per group -- TODO confirm
                                         param.num_batch / param.num_group,
                                         param.seq_q_offsets_ptr,
                                         // self-attention reuses the Q offsets for KV
                                         param.is_cross_attention ? param.seq_kv_offsets_ptr
                                                                  : param.seq_q_offsets_ptr,
                                         param.group_max_seqlen_ptr,
                                         param.group_contextual_seqlen_ptr,
                                         param.group_window_size_ptr,
                                         param.group_min_full_attn_seqlen_ptr,
                                         param.group_attn_scale_ptr,
                                         param.hdim_qk,
                                         param.hdim_v,
                                         param.num_head,
                                         param.scale_s,
                                         param.seq_stride_q,
                                         param.seq_stride_k,
                                         param.seq_stride_v,
                                         param.seq_stride_bias,
                                         param.nhead_stride_q,
                                         param.nhead_stride_k,
                                         param.nhead_stride_v,
                                         param.nhead_stride_bias,
                                         param.num_targets_ptr,
                                         param.p_drop,
                                         param.philox_seed,
                                         param.philox_offset);
        }();
        dim3 kGridSize = HstuKernel::GridSize(
            param.num_batch, param.num_head, param.max_seqlen, param.hdim_v, param.num_splits);
        constexpr dim3 kBlockSize              = HstuKernel::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
        (void)ck_tile::launch_kernel(
            ck_tile::stream_config{stream, false},
            ck_tile::make_kernel<kBlockPerCu>(HstuKernel{}, kGridSize, kBlockSize, 0, kargs));
    };

    /// Launches the combine kernel and frees the o_acc workspace (stream-ordered,
    /// so the free is safely enqueued after the kernel on the same stream).
    template <typename HstuKernel>
    static void RunWithFwdSplitKVCombineKernel(HstuAttentionGroupFwdParams& param,
                                               hipStream_t stream)
    {
        const auto kargs = [&] {
            return HstuKernel::MakeKargs(param.o_acc_ptr,
                                         param.o_ptr,
                                         param.seq_stride_o,
                                         param.nhead_stride_o,
                                         param.seq_q_offsets_ptr,
                                         param.num_head,
                                         param.num_splits,
                                         param.hdim_v);
        }();
        dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen);
        constexpr dim3 kBlockSize              = HstuKernel::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
        (void)ck_tile::launch_kernel(
            ck_tile::stream_config{stream, false},
            ck_tile::make_kernel<kBlockPerCu>(HstuKernel{}, kGridSize, kBlockSize, 0, kargs));
        HIP_CHECK_ERROR(hipFreeAsync(param.o_acc_ptr, stream));
    };
};

View File

@@ -22,6 +22,8 @@
#include "hstu_attention_fwd_kernel.hpp"
#include "hstu_attention_epilogue.hpp"
#include "hstu_attention_jagged_forward_splitkv_dispatch.hpp"
template <typename InOutDataType,
bool kUseCausal,
bool kUseSoftmax,
@@ -188,11 +190,36 @@ void run_jagged_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionNoGrou
MaxK,
128>::Run(param, stream);
else
jagged_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK,
64>::Run(param, stream);
{
const bool disable_fwd_splitkv = []() {
const char* env_p = std::getenv("HSTU_DISABLE_SPLITKV");
if(env_p == nullptr)
return false;
return static_cast<bool>(atoi(env_p));
}();
// TODO: enable splitkv when kUseSoftmax is true
if(!disable_fwd_splitkv && !kUseSoftmax &&
shall_use_splitkv(param.num_batch, param.num_head, param.max_seqlen))
{
if constexpr(!kUseSoftmax)
{
jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch<InOutDataType,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK,
64>::Run(param, stream);
};
}
else
jagged_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK,
64>::Run(param, stream);
};
};

View File

@@ -0,0 +1,261 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core/numeric/integer.hpp>
#include <ck_tile/host/kernel_launch.hpp>
#include <ck_tile/host/stream_config.hpp>
#include <ck_tile/ops/epilogue.hpp>
#include <ck_tile/host/hip_check_error.hpp>
#include "hstu_attention_bool_switch.hpp"
#include "hstu_attention_fwd_type_config.hpp"
#include "hstu_attention_fwd_setting.hpp"
#include "hstu_attention_fwd_splitkv_combine_setting.hpp"
#include "hstu_attention_params.hpp"
#include "hstu_attention_hdim_switch.hpp"
#include "hstu_attention_pipeline_problem.hpp"
#include "hstu_attention_traits.hpp"
#include "hstu_attention_with_softmax_fwd_pipeline.hpp"
#include "hstu_attention_no_softmax_fwd_pipeline.hpp"
#include "hstu_attention_with_softmax_fwd_trload_pipeline.hpp"
#include "hstu_attention_no_softmax_fwd_trload_pipeline.hpp"
#include "hstu_attention_no_softmax_fwd_splitkv_combine_pipeline.hpp"
#include "hstu_attention_fwd_splitkv_kernel.hpp"
#include "hstu_attention_fwd_splitkv_combine_kernel.hpp"
// Dispatch for the jagged (varlen) HSTU attention forward pass on the split-KV
// path.  The forward runs as two kernel launches on the same stream:
//   1) RunWithFwdSplitKVKernel  - partitions the KV work over num_splits and
//      writes per-split partial outputs into a temporary o_acc workspace
//      (allocated here in OaccDataType precision);
//   2) RunWithFwdSplitKVCombineKernel - reduces the partial outputs into the
//      final O tensor, then frees the workspace.
// Only the no-softmax variant is currently supported (see static_assert below).
template <typename InOutDataType,
          bool kUseCausal,
          bool kUseSoftmax,
          bool kHasBias,
          bool kHasDropout,
          ck_tile::index_t MaxK,
          ck_tile::index_t MTile>
struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
{
    static_assert(kUseSoftmax == false, "Softmax support is not enabled yet!");
    static_assert(MTile == 64, "MTile must be 64 to get to fwd splitkv path!");
    // Tile shapes for the split-KV (partial) kernel; the softmax branch is kept
    // for the future but is unreachable today because of the static_assert above.
    using HstuAttentionFwdTileSetting =
        typename std::conditional_t<kUseSoftmax,
                                    HstuAttentionWithSoftmaxFwdTileSetting<MaxK, MTile>,
                                    HstuAttentionNoSoftmaxFwdTileSetting<MaxK, MTile>>::Type;
    // Tile shape for the combine (reduction) kernel.
    using HstuAttentionCombineTileSetting = HstuAttentionFwdSplitKVCombineTileSetting<MaxK>::Type;
#ifdef BUILD_HSTU_FOR_GFX95_ONLY
    static constexpr bool kUseTrLoad = true;
#else
    static constexpr bool kUseTrLoad = false;
#endif
    // Pipeline problem for the split-KV kernel; only kIsCrossAttention is left
    // open here since cross-attention is a runtime property of the input.
    template <bool kIsCrossAttention>
    using HstuFwdPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
        InOutDataType,
        typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType,
        typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType,
        typename HstuAttentionFwdTypeConfig<InOutDataType>::BiasDataType,
        kIsCrossAttention,
        false, // kUseGroup
        true,  // kIsJagged
        kHasBias,
        kHasDropout,
        kUseCausal,
        kUseSoftmax,
        kUseTrLoad,
        HstuAttentionFwdTileSetting>;
    using OaccDataType = HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType;
    using ODataType = HstuAttentionFwdTypeConfig<InOutDataType>::ODataType;
    // Entry point: instantiates and launches the split-KV kernel followed by
    // the combine kernel.  Both launches use the same stream, so the o_acc
    // workspace hand-off between them is stream-ordered and safe.
    static void Run(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
    {
        constexpr ck_tile::index_t occupancy = -1; // -1: keep each pipeline's default occupancy
        // ---- Phase 1: split-KV partial kernel -------------------------------
        {
            // true when the head dims are not multiples of the tile lengths,
            // selecting the padded kernel instantiation
            const bool pad_headdim_qk =
                !(param.hdim_qk % HstuAttentionFwdTileSetting::kQKHeaddim == 0);
            const bool pad_headdim_v = !(param.hdim_v % HstuAttentionFwdTileSetting::kN1 == 0);
            // no need to check seqlen_q since it is not used as fastest dim,
            // buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
            constexpr bool kPadSeqLenQ = false;
            constexpr bool kPadSeqLenK = true;
            BOOL_SWITCH_2(pad_headdim_qk, kPadHeadDimQK, pad_headdim_v, kPadHeadDimV, [&] {
                using HstuTraits = ck_tile::HstuAttentionFwdTraits<kPadSeqLenQ,
                                                                   kPadSeqLenK,
                                                                   kPadHeadDimQK,
                                                                   kPadHeadDimV,
                                                                   occupancy>;
                // The split-KV kernel writes partial results, so its epilogue
                // keeps the accumulator precision instead of converting to the
                // final output type.
                using HstuEpilogue = ck_tile::Default2DEpilogue<
                    ck_tile::Default2DEpilogueProblem<OaccDataType,
                                                      OaccDataType, // keep output as OaccDataType
                                                      kPadSeqLenQ,
                                                      kPadHeadDimV,
                                                      false>>;
                BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] {
                    using HstuPipelineProblem = HstuFwdPipelineProblemTemp<kIsCrossAttention>;
                    if constexpr(!kUseTrLoad)
                    {
                        using HstuPipeline = std::conditional_t<
                            kUseSoftmax,
                            ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem,
                                                                              HstuTraits>,
                            ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem,
                                                                            HstuTraits>>;
                        using HstuKernel =
                            ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;
                        RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
                    }
                    else
                    {
                        // gfx95 builds use the transposed-load pipeline variants.
                        using HstuPipeline = std::conditional_t<
                            kUseSoftmax,
                            ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad<
                                HstuPipelineProblem,
                                HstuTraits>,
                            ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad<
                                HstuPipelineProblem,
                                HstuTraits>>;
                        using HstuKernel =
                            ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;
                        RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
                    };
                });
            });
        };
        // ---- Phase 2: combine kernel (reduce the per-split partial outputs) -
        {
            using HstuCombinePipelineProblem =
                ck_tile::HstuAttentionFwdSplitKVCombinePipelineProblem<
                    OaccDataType,
                    ODataType,
                    true /* kIsJagged */,
                    kUseSoftmax,
                    HstuAttentionCombineTileSetting>;
            const bool pad_headdim_o =
                !(param.hdim_v % HstuAttentionCombineTileSetting::kOHeaddim == 0);
            // no need to check seqlen_q since it is not used as fastest dim,
            // buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
            constexpr bool kPadSeqLenQ = false;
            BOOL_SWITCH(pad_headdim_o, kPadHeadDimO, [&] {
                using HstuTraits = ck_tile::
                    HstuAttentionFwdSplitKVCombineTraits<kPadSeqLenQ, kPadHeadDimO, occupancy>;
                // The combine epilogue converts from accumulator precision to
                // the final output type.
                using HstuEpilogue =
                    ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<OaccDataType,
                                                                                 ODataType,
                                                                                 kPadSeqLenQ,
                                                                                 kPadHeadDimO,
                                                                                 false>>;
                using HstuPipeline = ck_tile::HstuAttentionNoSoftmaxFwdSplitKVCombinePipeline<
                    HstuCombinePipelineProblem,
                    HstuTraits>;
                using HstuKernel =
                    ck_tile::HstuAttentionFwdSplitKVCombineKernel<HstuPipeline, HstuEpilogue>;
                RunWithFwdSplitKVCombineKernel<HstuKernel>(param, stream);
            });
        };
    };
    // Allocates the o_acc workspace, derives num_splits, and launches the
    // split-KV partial kernel.  param.num_splits and param.o_acc_ptr are
    // filled in here for the subsequent combine launch.
    template <typename HstuKernel>
    static void RunWithFwdSplitKVKernel(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
    {
        param.num_splits =
            get_suggested_num_splits(param.num_batch, param.num_head, param.max_seqlen);
        // assume the workspace for o_acc is in compact shape of [num_batch, max_seqlen, num_head,
        // num_splits, hdim]
        size_t workspace_bytes = static_cast<size_t>(param.num_batch) * param.max_seqlen *
                                 param.num_head * param.num_splits * param.hdim_v *
                                 sizeof(OaccDataType);
        HIP_CHECK_ERROR(hipMallocAsync(&param.o_acc_ptr, workspace_bytes, stream));
        const auto kargs = [&] {
            return HstuKernel::MakeKargs(param.q_ptr,
                                         param.k_ptr,
                                         param.v_ptr,
                                         param.bias_ptr,
                                         param.o_acc_ptr,
                                         param.num_splits,
                                         param.seq_q_offsets_ptr,
                                         // self-attention reuses the Q offsets for KV
                                         param.is_cross_attention ? param.seq_kv_offsets_ptr
                                                                  : param.seq_q_offsets_ptr,
                                         param.max_seqlen,
                                         param.hdim_qk,
                                         param.hdim_v,
                                         param.num_head,
                                         param.scale_s,
                                         param.attn_scale,
                                         param.seq_stride_q,
                                         param.seq_stride_k,
                                         param.seq_stride_v,
                                         param.seq_stride_bias,
                                         param.nhead_stride_q,
                                         param.nhead_stride_k,
                                         param.nhead_stride_v,
                                         param.nhead_stride_bias,
                                         param.num_targets_ptr,
                                         param.contextual_seqlen,
                                         param.window_size,
                                         param.min_full_attn_seqlen,
                                         param.p_drop,
                                         param.philox_seed,
                                         param.philox_offset);
        }();
        bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0);
        dim3 kGridSize = HstuKernel::GridSize(param.num_batch,
                                              param.num_head,
                                              param.max_seqlen,
                                              param.hdim_v,
                                              param.num_splits,
                                              has_minfull_attn_seqlen);
        constexpr dim3 kBlockSize = HstuKernel::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
        (void)ck_tile::launch_kernel(
            ck_tile::stream_config{stream, false},
            ck_tile::make_kernel<kBlockPerCu>(HstuKernel{}, kGridSize, kBlockSize, 0, kargs));
    };
    // Launches the combine kernel that reduces the per-split partial outputs
    // into param.o_ptr, then frees the o_acc workspace on the same stream
    // (stream ordering guarantees the free happens after the kernel).
    template <typename HstuKernel>
    static void RunWithFwdSplitKVCombineKernel(HstuAttentionNoGroupFwdParams& param,
                                               hipStream_t stream)
    {
        const auto kargs = [&] {
            return HstuKernel::MakeKargs(param.o_acc_ptr,
                                         param.o_ptr,
                                         param.seq_stride_o,
                                         param.nhead_stride_o,
                                         param.seq_q_offsets_ptr,
                                         param.num_head,
                                         param.num_splits,
                                         param.hdim_v);
        }();
        dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen);
        constexpr dim3 kBlockSize = HstuKernel::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
        (void)ck_tile::launch_kernel(
            ck_tile::stream_config{stream, false},
            ck_tile::make_kernel<kBlockPerCu>(HstuKernel{}, kGridSize, kBlockSize, 0, kargs));
        HIP_CHECK_ERROR(hipFreeAsync(param.o_acc_ptr, stream));
    };
};

View File

@@ -180,6 +180,14 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
OaccBlockTileType o_acc;
if(seqlen_k_end <= seqlen_k_start)
{
clear_tile(o_acc);
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
};
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
q_dram_block_window_tmp.get_window_origin(),

View File

@@ -0,0 +1,103 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core.hpp>
#include "hstu_attention_fwd_splitkv_combine_pipeline_policy.hpp"
namespace ck_tile {
// Combine pipeline for the split-KV HSTU attention forward without softmax.
// Each workgroup loads its kM x kOHeaddim tile of the per-split partial
// output (o_acc) for split 0, accumulates the matching tiles of the remaining
// num_splits - 1 splits, and returns the summed tile.  With softmax disabled
// the combine is a plain element-wise sum (no log-sum-exp rescaling).
template <typename Problem_,
          typename Traits_,
          typename Policy_ = HstuAttentionFwdSplitKVCombinePipelinePolicy>
struct HstuAttentionNoSoftmaxFwdSplitKVCombinePipeline
{
    using Problem = remove_cvref_t<Problem_>;
    using Traits = remove_cvref_t<Traits_>;
    using Policy = remove_cvref_t<Policy_>;
    using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
    using ODataType = remove_cvref_t<typename Problem::ODataType>;
    static constexpr index_t kBlockSize = Problem::kBlockSize;
    // Tile extents along seqlen_q (kM) and the output head dimension (kOHeaddim).
    static constexpr index_t kM = Problem::kM;
    static constexpr index_t kOHeaddim = Problem::kOHeaddim;
    static_assert(kOHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
    static_assert(Problem::kUseSoftmax == false, "This pipeline only works with not-using softmax");
    static constexpr bool kIsJagged = Problem::kIsJagged;
    static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
    static constexpr bool kPadHeadDimO = Traits::kPadHeadDimO;
    // last dimension vector length used to create tensor view(and decide buffer_load vector length)
    // ... together with tensor distribution. tensor dist should able to overwrite this
    static constexpr index_t kAlignmentO =
        kPadHeadDimO ? 1 : Policy::template GetAlignmentO<Problem>();
    // Occupancy: honor the traits override when it is not -1, otherwise
    // default to 2 blocks per CU.
    static constexpr index_t kBlockPerCu = []() {
        if constexpr(Traits::kBlockPerCu != -1)
            return Traits::kBlockPerCu;
        else
        {
            return 2;
        }
    }();
    static constexpr const char* name = "hstu_no_softmax_fwd_splitkv_combine";
    // The reduction is done entirely in registers; no LDS is required.
    CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
    // Sum this block's tile across all splits.
    // @param o_acc_dram_block_window_tmp  window over the block's kM x kOHeaddim
    //                                     tile of split 0 in the o_acc buffer
    // @param o_acc_element_func           element-wise functor applied to the
    //                                     summed tile before it is returned
    // @param split_stride                 offset added to the o_acc base pointer
    //                                     to step from split i to split i+1
    // @param num_splits                   number of per-split partial results
    // @return the accumulated kM x kOHeaddim tile (OaccDataType)
    template <typename OAccDramBlockWindowTmp, typename OAccElementFunction>
    CK_TILE_DEVICE auto
    operator()(const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // M0*kOHeaddim tile
               const OAccElementFunction& o_acc_element_func,
               ck_tile::index_t split_stride,
               ck_tile::index_t num_splits) const
    {
        static_assert(
            std::is_same_v<OaccDataType, remove_cvref_t<typename OAccDramBlockWindowTmp::DataType>>,
            "wrong!");
        static_assert(kM == OAccDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kOHeaddim == OAccDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                      "wrong!");
        auto o_acc_dram_window =
            make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(),
                             make_tuple(number<kM>{}, number<kOHeaddim>{}),
                             o_acc_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeOaccDramTileDistribution<Problem>());
        // Base pointer of split 0; split i lives at base + split_stride * i.
        auto o_acc_ptr = o_acc_dram_window.get_bottom_tensor_view().get_buffer_view().p_data_;
        auto o_acc = load_tile(o_acc_dram_window);
        for(int i = 1; i < num_splits; i++)
        {
            // Retarget the same window at split i and accumulate its tile.
            o_acc_dram_window.set_bottom_tensor_view_data_ptr(o_acc_ptr + split_stride * i);
            auto o_acc_tile = load_tile(o_acc_dram_window);
            tile_elementwise_inout([](auto& x, const auto& y) { x = x + y; }, o_acc, o_acc_tile);
        };
        o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
        return o_acc;
    }
    // Convenience overload with no element function (identity pass-through).
    template <typename OAccDramBlockWindowTmp>
    CK_TILE_DEVICE auto
    operator()(const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // kM*kOHeaddim tile
               ck_tile::index_t split_stride,
               ck_tile::index_t num_splits) const
    {
        return operator()(o_acc_dram_block_window_tmp, identity{}, split_stride, num_splits);
    };
};
} // namespace ck_tile

View File

@@ -180,6 +180,14 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
OaccBlockTileType o_acc;
if(seqlen_k_end <= seqlen_k_start)
{
clear_tile(o_acc);
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
};
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
q_dram_block_window_tmp.get_window_origin(),

View File

@@ -65,6 +65,13 @@ struct HstuAttentionNoGroupFwdParams
float p_drop;
uint64_t philox_seed;
uint64_t philox_offset;
// this need not be set by the API users, we only use it for passing num_splits between splitkv
// and combine kernel
int num_splits;
// pointer of device memory allocated before calling fwd_splitkv kernel and released after
// calling combine kernel
void* o_acc_ptr;
};
struct HstuAttentionGroupFwdParams
@@ -125,4 +132,11 @@ struct HstuAttentionGroupFwdParams
float p_drop;
uint64_t philox_seed;
uint64_t philox_offset;
// this need not be set by the API users, we only use it for passing num_splits between splitkv
// and combine kernel
int num_splits;
// pointer of device memory allocated before calling fwd_splitkv kernel and released after
// calling combine kernel
void* o_acc_ptr;
};

View File

@@ -133,4 +133,42 @@ struct HstuAttentionFwdPipelineProblem
};
};
// Pipeline-problem descriptor for the split-KV combine kernel: bundles the
// accumulator/output element types with the tile geometry taken from the
// combine tile setting.
template <typename AccT_,
          typename OutT_,
          bool IsJagged_,
          bool UseSoftmax_,
          typename TileSetting_>
struct HstuAttentionFwdSplitKVCombinePipelineProblem
{
    using OaccDataType = remove_cvref_t<AccT_>; // partial-output (accumulator) type
    using ODataType    = remove_cvref_t<OutT_>; // final-output type

    static constexpr bool kIsJagged   = IsJagged_;
    static constexpr bool kUseSoftmax = UseSoftmax_;

    // Tile geometry, forwarded from the combine tile setting.
    static constexpr index_t kM           = TileSetting_::kM;
    static constexpr index_t NumWarps     = TileSetting_::NumWarps;
    static constexpr index_t kOHeaddim    = TileSetting_::kOHeaddim;
    static constexpr index_t kSubOHeaddim = TileSetting_::kSubOHeaddim;
    // One workgroup = NumWarps warps.
    static constexpr index_t kBlockSize = NumWarps * get_warp_size();

    // Maximum DRAM access vector size when reading the o_acc tile.
    CK_TILE_HOST_DEVICE static constexpr auto GetOaccDramTileAccessMaxVectorSize()
    {
        return detail::GetDramTileAccessMaxVectorSize<OaccDataType, kBlockSize, kM, kOHeaddim>();
    };
    // Maximum DRAM access vector size when writing the final O tile.
    CK_TILE_HOST_DEVICE static constexpr auto GetODramTileAccessMaxVectorSize()
    {
        return detail::GetDramTileAccessMaxVectorSize<ODataType, kBlockSize, kM, kOHeaddim>();
    };
};
} // namespace ck_tile

View File

@@ -64,4 +64,21 @@ struct HstuAttentionFwdTileSettingClass
static_assert(kSubQKHeaddim % kN1 == 0, "Check failed!");
};
// Tile-setting descriptor for the fwd split-KV combine kernel.
template <index_t TileM_,    // tile size in seqlen_q dimension
          index_t Warps_,    // warps per workgroup; all warps cover the seqlen_q dimension
          index_t HeadDimO_> // tile size in hdim_o dimension
struct HstuAttentionFwdSplitKVCombineTileSettingClass
{
    static constexpr index_t kM        = TileM_;
    static constexpr index_t NumWarps  = Warps_;
    static constexpr index_t kOHeaddim = HeadDimO_;
    // hdim_o tile length after qualification (see ceil_to_qualified_tile_length).
    static constexpr index_t kSubOHeaddim = ceil_to_qualified_tile_length(kOHeaddim);

    // The seqlen_q tile must divide evenly over the warps, and the whole tile
    // must divide evenly over the workgroup's lanes.
    static_assert(kM % NumWarps == 0, "Check failed!");
    static_assert((kM * kOHeaddim) % (NumWarps * get_warp_size()) == 0, "Check failed!");
};
} // namespace ck_tile

View File

@@ -22,4 +22,15 @@ struct HstuAttentionFwdTraits
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
// Traits controlling padding and occupancy for the fwd split-KV combine kernel.
template <bool PadSeqLenQ_,         // enable padding along seqlen_q
          bool PadHeadDimO_,        // enable padding along hdim_o
          index_t BlockPerCu_ = -1> // overwrite occupancy if not -1
struct HstuAttentionFwdSplitKVCombineTraits
{
    static constexpr bool kPadSeqLenQ    = PadSeqLenQ_;
    static constexpr bool kPadHeadDimO   = PadHeadDimO_;
    static constexpr index_t kBlockPerCu = BlockPerCu_;
};
} // namespace ck_tile