mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-14 10:09:41 +00:00

Add implementation of fwd splitkv on no_softmax path
@@ -22,6 +22,8 @@
#include "hstu_attention_fwd_kernel.hpp"
#include "hstu_attention_epilogue.hpp"

#include "hstu_attention_batched_forward_splitkv_dispatch.hpp"

template <typename InOutDataType,
          bool kUseCausal,
          bool kUseSoftmax,
@@ -190,7 +192,7 @@ template <typename InOutDataType,
void run_batched_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionNoGroupFwdParams& param,
                                                              hipStream_t stream)
{
    if(get_hstu_attention_fwd_mtile(param.num_batch, param.num_head, param.max_seqlen) == 128)
    if(get_hstu_attention_fwd_mtile(param.num_batch, param.num_head, param.seqlen_q) == 128)
        batched_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
                                                             kUseCausal,
                                                             kUseSoftmax,
@@ -199,11 +201,37 @@ void run_batched_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionNoGro
                                                             MaxK,
                                                             128>::Run(param, stream);
    else
        batched_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
                                                             kUseCausal,
                                                             kUseSoftmax,
                                                             kHasBias,
                                                             kHasDropout,
                                                             MaxK,
                                                             64>::Run(param, stream);
    {
        const bool disable_fwd_splitkv = []() {
            const char* env_p = std::getenv("HSTU_DISABLE_SPLITKV");
            if(env_p == nullptr)
                return false;
            return static_cast<bool>(atoi(env_p));
        }();

        // ToDo: enable splitkv when kUseSoftmax is true
        if(!disable_fwd_splitkv && !kUseSoftmax &&
           shall_use_splitkv(param.num_batch, param.num_head, param.seqlen_q))
        {
            if constexpr(!kUseSoftmax)
            {
                batched_forward_splitkv_causal_softmax_bias_dropout_dispatch<InOutDataType,
                                                                             kUseCausal,
                                                                             kUseSoftmax,
                                                                             kHasBias,
                                                                             kHasDropout,
                                                                             MaxK,
                                                                             64>::Run(param,
                                                                                      stream);
            };
        }
        else
            batched_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
                                                                 kUseCausal,
                                                                 kUseSoftmax,
                                                                 kHasBias,
                                                                 kHasDropout,
                                                                 MaxK,
                                                                 64>::Run(param, stream);
    };
};
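The splitkv path added above can be disabled at runtime via the HSTU_DISABLE_SPLITKV environment variable read in the lambda; any non-zero integer turns it off. A minimal usage sketch (the binary name is hypothetical):

// Opt out of the fwd splitkv path for a single run:
//   HSTU_DISABLE_SPLITKV=1 ./run_hstu_attention_fwd
// Leaving the variable unset, or setting it to "0", keeps the
// heuristic-driven splitkv dispatch enabled.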

@@ -0,0 +1,274 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <ck_tile/core/numeric/integer.hpp>
#include <ck_tile/host/kernel_launch.hpp>
#include <ck_tile/host/stream_config.hpp>
#include <ck_tile/ops/epilogue.hpp>
#include <ck_tile/host/hip_check_error.hpp>

#include "hstu_attention_bool_switch.hpp"
#include "hstu_attention_fwd_type_config.hpp"
#include "hstu_attention_fwd_setting.hpp"
#include "hstu_attention_fwd_splitkv_combine_setting.hpp"
#include "hstu_attention_params.hpp"
#include "hstu_attention_hdim_switch.hpp"
#include "hstu_attention_pipeline_problem.hpp"
#include "hstu_attention_traits.hpp"
#include "hstu_attention_with_softmax_fwd_pipeline.hpp"
#include "hstu_attention_no_softmax_fwd_pipeline.hpp"
#include "hstu_attention_with_softmax_fwd_trload_pipeline.hpp"
#include "hstu_attention_no_softmax_fwd_trload_pipeline.hpp"
#include "hstu_attention_no_softmax_fwd_splitkv_combine_pipeline.hpp"
#include "hstu_attention_fwd_splitkv_kernel.hpp"
#include "hstu_attention_fwd_splitkv_combine_kernel.hpp"

template <typename InOutDataType,
          bool kUseCausal,
          bool kUseSoftmax,
          bool kHasBias,
          bool kHasDropout,
          ck_tile::index_t MaxK,
          ck_tile::index_t MTile>
struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
{
    static_assert(kUseSoftmax == false, "Softmax support is not enabled yet!");
    static_assert(MTile == 64, "MTile must be 64 to get to fwd splitkv path!");

    using HstuAttentionFwdTileSetting =
        typename std::conditional_t<kUseSoftmax,
                                    HstuAttentionWithSoftmaxFwdTileSetting<MaxK, MTile>,
                                    HstuAttentionNoSoftmaxFwdTileSetting<MaxK, MTile>>::Type;
    using HstuAttentionCombineTileSetting = HstuAttentionFwdSplitKVCombineTileSetting<MaxK>::Type;

#ifdef BUILD_HSTU_FOR_GFX95_ONLY
    static constexpr bool kUseTrLoad = true;
#else
    static constexpr bool kUseTrLoad = false;
#endif

    template <bool kIsCrossAttention>
    using HstuFwdPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
        InOutDataType,
        typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType,
        typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType,
        typename HstuAttentionFwdTypeConfig<InOutDataType>::BiasDataType,
        kIsCrossAttention,
        false, // kUseGroup
        false, // kIsJagged
        kHasBias,
        kHasDropout,
        kUseCausal,
        kUseSoftmax,
        kUseTrLoad,
        HstuAttentionFwdTileSetting>;

    using OaccDataType = HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType;
    using ODataType = HstuAttentionFwdTypeConfig<InOutDataType>::ODataType;

    static void Run(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
    {
        constexpr ck_tile::index_t occupancy = -1;

        {
            const bool pad_seqlen_k = !(param.seqlen_kv % HstuAttentionFwdTileSetting::kN0 == 0);
            const bool pad_headdim_qk =
                !(param.hdim_qk % HstuAttentionFwdTileSetting::kQKHeaddim == 0);
            const bool pad_headdim_v = !(param.hdim_v % HstuAttentionFwdTileSetting::kN1 == 0);

            // no need to check seqlen_q since it is not used as fastest dim,
            // buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
            constexpr bool kPadSeqLenQ = false;

            BOOL_SWITCH_3(
                pad_seqlen_k,
                kPadSeqLenK,
                pad_headdim_qk,
                kPadHeadDimQK,
                pad_headdim_v,
                kPadHeadDimV,
                [&] {
                    using HstuTraits = ck_tile::HstuAttentionFwdTraits<kPadSeqLenQ,
                                                                       kPadSeqLenK,
                                                                       kPadHeadDimQK,
                                                                       kPadHeadDimV,
                                                                       occupancy>;

                    using HstuEpilogue =
                        ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
                            OaccDataType,
                            OaccDataType, // keep output as OaccDataType
                            kPadSeqLenQ,
                            kPadHeadDimV,
                            false>>;

                    BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] {
                        using HstuPipelineProblem = HstuFwdPipelineProblemTemp<kIsCrossAttention>;

                        if constexpr(!kUseTrLoad)
                        {
                            using HstuPipeline = std::conditional_t<
                                kUseSoftmax,
                                ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVS<
                                    HstuPipelineProblem,
                                    HstuTraits>,
                                ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVS<
                                    HstuPipelineProblem,
                                    HstuTraits>>;

                            using HstuKernel =
                                ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;

                            RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
                        }
                        else
                        {
                            using HstuPipeline = std::conditional_t<
                                kUseSoftmax,
                                ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad<
                                    HstuPipelineProblem,
                                    HstuTraits>,
                                ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad<
                                    HstuPipelineProblem,
                                    HstuTraits>>;

                            using HstuKernel =
                                ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;

                            RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
                        };
                    });
                });
        };
        {
            using HstuCombinePipelineProblem =
                ck_tile::HstuAttentionFwdSplitKVCombinePipelineProblem<
                    OaccDataType,
                    ODataType,
                    false /* kIsJagged */,
                    kUseSoftmax,
                    HstuAttentionCombineTileSetting>;
            const bool pad_headdim_o =
                !(param.hdim_v % HstuAttentionCombineTileSetting::kOHeaddim == 0);

            // no need to check seqlen_q since it is not used as fastest dim,
            // buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
            constexpr bool kPadSeqLenQ = false;

            BOOL_SWITCH(pad_headdim_o, kPadHeadDimO, [&] {
                using HstuTraits = ck_tile::
                    HstuAttentionFwdSplitKVCombineTraits<kPadSeqLenQ, kPadHeadDimO, occupancy>;

                using HstuEpilogue =
                    ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<OaccDataType,
                                                                                 ODataType,
                                                                                 kPadSeqLenQ,
                                                                                 kPadHeadDimO,
                                                                                 false>>;

                using HstuPipeline = ck_tile::HstuAttentionNoSoftmaxFwdSplitKVCombinePipeline<
                    HstuCombinePipelineProblem,
                    HstuTraits>;

                using HstuKernel =
                    ck_tile::HstuAttentionFwdSplitKVCombineKernel<HstuPipeline, HstuEpilogue>;

                RunWithFwdSplitKVCombineKernel<HstuKernel>(param, stream);
            });
        };
    };

    template <typename HstuKernel>
    static void RunWithFwdSplitKVKernel(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
    {
        param.num_splits =
            get_suggested_num_splits(param.num_batch, param.num_head, param.seqlen_q);

        // assume the workspace for o_acc is in compact shape of [num_batch, seqlen_q, num_head,
        // num_splits, hdim]
        size_t workspace_bytes = static_cast<size_t>(param.num_batch) * param.seqlen_q *
                                 param.num_head * param.num_splits * param.hdim_v *
                                 sizeof(OaccDataType);

        HIP_CHECK_ERROR(hipMallocAsync(&param.o_acc_ptr, workspace_bytes, stream));
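        // Illustrative sizing only (the numbers below are assumptions, not from
        // this change): with num_batch=8, seqlen_q=2048, num_head=4, num_splits=4,
        // hdim_v=128 and a float OaccDataType, the o_acc workspace is
        // 8 * 2048 * 4 * 4 * 128 * 4 B = 128 MiB, so the split count chosen above
        // directly scales this temporary allocation.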

        const auto kargs = [&] {
            return HstuKernel::MakeKargs(param.q_ptr,
                                         param.k_ptr,
                                         param.v_ptr,
                                         param.bias_ptr,
                                         param.o_acc_ptr,
                                         param.num_splits,
                                         param.seqlen_q,
                                         param.is_cross_attention ? param.seqlen_kv
                                                                  : param.seqlen_q,
                                         param.hdim_qk,
                                         param.hdim_v,
                                         param.num_head,
                                         param.scale_s,
                                         param.attn_scale,
                                         param.seq_stride_q,
                                         param.seq_stride_k,
                                         param.seq_stride_v,
                                         param.seq_stride_bias,
                                         param.nhead_stride_q,
                                         param.nhead_stride_k,
                                         param.nhead_stride_v,
                                         param.nhead_stride_bias,
                                         param.batch_stride_q,
                                         param.batch_stride_k,
                                         param.batch_stride_v,
                                         param.batch_stride_bias,
                                         param.num_targets_ptr,
                                         param.contextual_seqlen,
                                         param.window_size,
                                         param.min_full_attn_seqlen,
                                         param.p_drop,
                                         param.philox_seed,
                                         param.philox_offset);
        }();

        bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0);
        dim3 kGridSize = HstuKernel::GridSize(param.num_batch,
                                              param.num_head,
                                              param.seqlen_q,
                                              param.hdim_v,
                                              param.num_splits,
                                              has_minfull_attn_seqlen);
        constexpr dim3 kBlockSize = HstuKernel::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;

        (void)ck_tile::launch_kernel(
            ck_tile::stream_config{stream, false},
            ck_tile::make_kernel<kBlockPerCu>(HstuKernel{}, kGridSize, kBlockSize, 0, kargs));
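        // Note (hedged reading of ck_tile's host API): the second stream_config
        // member is the time-kernel flag, so {stream, false} launches without
        // built-in timing, and the (void) cast discards launch_kernel's timing
        // return value.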
    };

    template <typename HstuKernel>
    static void RunWithFwdSplitKVCombineKernel(HstuAttentionNoGroupFwdParams& param,
                                               hipStream_t stream)
    {
        const auto kargs = [&] {
            return HstuKernel::MakeKargs(param.o_acc_ptr,
                                         param.o_ptr,
                                         param.batch_stride_o,
                                         param.seq_stride_o,
                                         param.nhead_stride_o,
                                         param.seqlen_q,
                                         param.num_head,
                                         param.num_splits,
                                         param.hdim_v);
        }();

        dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.seqlen_q);
        constexpr dim3 kBlockSize = HstuKernel::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;

        (void)ck_tile::launch_kernel(
            ck_tile::stream_config{stream, false},
            ck_tile::make_kernel<kBlockPerCu>(HstuKernel{}, kGridSize, kBlockSize, 0, kargs));

        HIP_CHECK_ERROR(hipFreeAsync(param.o_acc_ptr, stream));
    };
};
@@ -484,3 +484,36 @@ static int get_hstu_attention_fwd_mtile(int num_batches, int num_heads, int max_
    // mtile-64 can be added through tuning/verification
    return 64;
};

static float get_estimated_cu_coverage_ratio(int num_batches, int num_heads, int max_seqlen_q)
{
    int num_CUs = get_number_of_cu();
    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };

    int nbatch_nhead_mblocks = num_batches * num_heads * ceildiv(max_seqlen_q, 64);

    // assume each CU can run two work-groups, a common case for hdim128
    return static_cast<float>(nbatch_nhead_mblocks) / (2.0f * num_CUs);
};

static bool shall_use_splitkv(int num_batches, int num_heads, int max_seqlen_q)
{
    // Please tune the threshold here
    const float threshold = 0.8f;

    if(get_estimated_cu_coverage_ratio(num_batches, num_heads, max_seqlen_q) < threshold)
        return true;
    return false;
};

static int get_suggested_num_splits(int num_batches, int num_heads, int max_seqlen_q)
{
    int i = 2;

    // Please tune the threshold here
    const float threshold = 1.5f;
    while(get_estimated_cu_coverage_ratio(num_batches, num_heads, max_seqlen_q) * i < threshold)
        i++;

    return i;
};
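A worked example of these heuristics, as a self-contained sketch (the 80-CU device count is an assumption; the real code queries it via get_number_of_cu()):

#include <cstdio>

int main()
{
    const int num_CUs = 80; // assumed device; not queried here
    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };

    // num_batches=2, num_heads=4, max_seqlen_q=512, mtile=64
    int mblocks = 2 * 4 * ceildiv(512, 64);   // 64 work-groups
    float ratio = mblocks / (2.0f * num_CUs); // 0.40, below the 0.8 threshold

    int splits = 2;
    while(ratio * splits < 1.5f) // grow splits until ~1.5x CU coverage
        splits++;

    std::printf("ratio=%.2f num_splits=%d\n", ratio, splits); // ratio=0.40 num_splits=4
    return 0;
}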

@@ -0,0 +1,282 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2026, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"

#include <string>
#include <type_traits>
#include <utility>
#include <variant>

#include "hstu_block_masking.hpp"

#ifndef HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
#define HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM 1
#endif

// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]

namespace ck_tile {

template <typename HstuAttentionPipeline_, typename EpiloguePipeline_>
struct HstuAttentionFwdSplitKVCombineKernel
{
    using HstuAttentionPipeline = ck_tile::remove_cvref_t<HstuAttentionPipeline_>;
    using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
    static constexpr ck_tile::index_t kBlockSize = HstuAttentionPipeline::kBlockSize;
    static constexpr ck_tile::index_t kBlockPerCu = HstuAttentionPipeline::kBlockPerCu;
    static_assert(kBlockPerCu > 0);

    using OaccDataType =
        ck_tile::remove_cvref_t<typename HstuAttentionPipeline::Problem::OaccDataType>;
    using ODataType = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::Problem::ODataType>;

    static constexpr bool kIsJagged = HstuAttentionPipeline::Problem::kIsJagged;

    static constexpr bool kPadSeqLenQ = HstuAttentionPipeline::kPadSeqLenQ;
    static constexpr bool kPadHeadDimO = HstuAttentionPipeline::kPadHeadDimO;

    // to avoid the duplicated-base-class problem, introduce a template arg
    template <ck_tile::index_t I>
    struct HstuAttentionCombineEmptyKargs
    {
    };

    // kargs use aggregate initialization, so no constructor is provided;
    // use inheritance to minimize karg size.
    // users need to use the MakeKargs() function to create kargs.
    struct HstuAttentionBatchedCombineBaseKargs
    {
        const void* o_acc_ptr;
        void* o_ptr;

        ck_tile::index_t batch_stride_o;
        ck_tile::index_t seq_stride_o;
        ck_tile::index_t nhead_stride_o;

        ck_tile::index_t seqlen_q;
        ck_tile::index_t num_head;
        ck_tile::index_t num_splits;
        ck_tile::index_t hdim_v;
    };

    struct HstuAttentionJaggedCombineBaseKargs
    {
        const void* o_acc_ptr;
        void* o_ptr;

        ck_tile::index_t seq_stride_o;
        ck_tile::index_t nhead_stride_o;

        const int32_t* seq_q_offsets_ptr;
        ck_tile::index_t num_head;
        ck_tile::index_t num_splits;
        ck_tile::index_t hdim_v;

        ck_tile::index_t seqlen_q;
    };

    struct HstuAttentionBatchedCombineKargs : HstuAttentionBatchedCombineBaseKargs
    {
    };

    struct HstuAttentionJaggedCombineKargs : HstuAttentionJaggedCombineBaseKargs
    {
    };

    using Kargs = std::
        conditional_t<kIsJagged, HstuAttentionJaggedCombineKargs, HstuAttentionBatchedCombineKargs>;

    template <bool Cond = !kIsJagged>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* o_acc_ptr, // workspace for accumulation of o
              void* o_ptr,
              ck_tile::index_t batch_stride_o,
              ck_tile::index_t seq_stride_o,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t num_head,
              ck_tile::index_t num_splits, // number of seqlen_kv splits
              ck_tile::index_t hdim_v)
    {
        Kargs kargs{o_acc_ptr,
                    o_ptr,
                    batch_stride_o,
                    seq_stride_o,
                    nhead_stride_o,
                    seqlen_q,
                    num_head,
                    num_splits,
                    hdim_v};

        return kargs;
    }

    template <bool Cond = kIsJagged>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* o_acc_ptr, // workspace for accumulation of o
              void* o_ptr,
              ck_tile::index_t seq_stride_o,
              ck_tile::index_t nhead_stride_o,
              const void* seq_q_offsets_ptr,
              ck_tile::index_t num_head,
              ck_tile::index_t num_splits, // number of seqlen_kv splits
              ck_tile::index_t hdim_v)
    {
        Kargs kargs{o_acc_ptr,
                    o_ptr,
                    seq_stride_o,
                    nhead_stride_o,
                    reinterpret_cast<const int32_t*>(seq_q_offsets_ptr),
                    num_head,
                    num_splits,
                    hdim_v,
                    0 /* seqlen_q will be updated later */};

        return kargs;
    }

    CK_TILE_HOST static constexpr auto
    GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_)
    {
        ck_tile::index_t num_tile_in_seqlen =
            ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM);

#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
        return dim3(batch_size_, nhead_, num_tile_in_seqlen);
#else
        return dim3(num_tile_in_seqlen, nhead_, batch_size_);
#endif
    }
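    // Illustrative grid shape (hypothetical sizes): with batch_size_=8, nhead_=4,
    // seqlen_=1000 and kM=32, num_tile_in_seqlen = ceil(1000/32) = 32, so the
    // default scheduling launches dim3(8, 4, 32) thread blocks.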

    CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
    {
        ignore = kargs;

#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
        const index_t i_batch = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_tile_m = blockIdx.z;
#else
        const index_t i_tile_m = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;
#endif
        return ck_tile::make_tuple(i_tile_m, i_nhead, i_batch);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }

    CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return ck_tile::max(HstuAttentionPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex(kargs);

        long_index_t batch_offset_o_acc = 0;
        long_index_t batch_offset_o = 0;

        if constexpr(kIsJagged)
        {
            // get starting offset for each batch
            const long_index_t query_start = kargs.seq_q_offsets_ptr[i_batch];

            // assume o_acc is in compact shape of [batch_size, max_seqlen, num_head, num_splits,
            // hdim]
            batch_offset_o_acc = query_start * kargs.num_head * kargs.num_splits * kargs.hdim_v;

            batch_offset_o = query_start * kargs.seq_stride_o;

            kargs.seqlen_q =
                kargs.seq_q_offsets_ptr[i_batch + 1] - kargs.seq_q_offsets_ptr[i_batch];
        }
        else
        {
            // assume o_acc is in compact shape of [batch_size, seqlen_q, num_head, num_splits,
            // hdim]
            batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.seqlen_q *
                                 kargs.num_head * kargs.num_splits * kargs.hdim_v;

            batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
        }
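        // Worked example (hypothetical values): with seq_q_offsets = {0, 100, 250, 260}
        // and i_batch = 1, query_start is 100 and seqlen_q becomes 250 - 100 = 150;
        // this batch's o_acc rows then start 100 * num_head * num_splits * hdim_v
        // elements into the compact workspace.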

        index_t i_m0;

        i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM);

        if(kargs.seqlen_q <= i_m0)
            return;

        // assume o_acc is in compact shape of [batch_size, seqlen, num_head, num_splits, hdim]
        const OaccDataType* o_acc_ptr =
            reinterpret_cast<const OaccDataType*>(kargs.o_acc_ptr) +
            static_cast<long_index_t>(i_nhead) * kargs.num_splits * kargs.hdim_v +
            batch_offset_o_acc;
        ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
                           static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
                           batch_offset_o;

        // Oacc DRAM and Oacc DRAM window
        auto seq_stride_o_acc = kargs.num_head * kargs.num_splits * kargs.hdim_v;
        const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            o_acc_ptr,
            make_tuple(kargs.seqlen_q, kargs.hdim_v),
            make_tuple(seq_stride_o_acc, 1),
            number<HstuAttentionPipeline::kAlignmentO>{},
            number<1>{});

        const auto o_acc_dram =
            pad_tensor_view(o_acc_dram_naive,
                            make_tuple(number<HstuAttentionPipeline::kM>{},
                                       number<HstuAttentionPipeline::kOHeaddim>{}),
                            sequence<false, kPadHeadDimO>{});

        auto o_acc_dram_window =
            make_tile_window(o_acc_dram,
                             make_tuple(number<HstuAttentionPipeline::kM>{},
                                        number<HstuAttentionPipeline::kOHeaddim>{}),
                             {i_m0, 0});

        auto o_acc_tile = [&]() {
            return HstuAttentionPipeline{}(o_acc_dram_window, kargs.hdim_v, kargs.num_splits);
        }();

        // O DRAM and O DRAM window
        auto o_dram = [&]() {
            const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                o_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_v),
                make_tuple(kargs.seq_stride_o, 1),
                number<HstuAttentionPipeline::kAlignmentO>{},
                number<1>{});

            return pad_tensor_view(o_dram_naive,
                                   make_tuple(number<HstuAttentionPipeline::kM>{},
                                              number<HstuAttentionPipeline::kOHeaddim>{}),
                                   sequence<false, kPadHeadDimO>{});
        }();

        auto o_dram_window =
            make_tile_window(o_dram,
                             make_tuple(number<HstuAttentionPipeline::kM>{},
                                        number<HstuAttentionPipeline::kOHeaddim>{}),
                             {i_m0, 0});

        EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
    }
};

} // namespace ck_tile
@@ -0,0 +1,83 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <ck_tile/core.hpp>

namespace ck_tile {

struct HstuAttentionFwdSplitKVCombinePipelinePolicy
{
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution()
    {
        constexpr index_t kMPerBlock = Problem::kM;
        constexpr index_t kKPerBlock = Problem::kOHeaddim;
        constexpr index_t NumWarps = Problem::NumWarps;

        constexpr index_t KVector = GetAlignmentOacc<Problem>();
        constexpr index_t OtherK = kKPerBlock / KVector;

        if constexpr(kKPerBlock == Problem::kSubOHeaddim)
        // for kKPerBlock=32,64,128,256
        {
            static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!");

            constexpr index_t KPerThread = KVector;

            // try to assign more consecutive threads on dim-K
            constexpr index_t KThreads = OtherK;

            static_assert(KThreads <= get_warp_size(), "Check failed!");

            constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
            constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);

            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
                                                 sequence<KThreads, KPerThread>>,
                                           tuple<sequence<1>, sequence<1, 2>>,
                                           tuple<sequence<1>, sequence<2, 0>>,
                                           sequence<1, 2>,
                                           sequence<0, 1>>{});
        }
        else // for kKPerBlock=96,160
        {
            static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!");

            // ensure KThreads is a power-of-2 integer
            constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5;
            constexpr index_t KThreads = OtherK / KRepPerThread;

            constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
            constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);

            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
                                                 sequence<KRepPerThread, KThreads, KVector>>,
                                           tuple<sequence<1>, sequence<1, 2>>,
                                           tuple<sequence<1>, sequence<2, 1>>,
                                           sequence<1, 2, 2>,
                                           sequence<0, 0, 2>>{});
        };
    }
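    // Concrete mapping (hypothetical, for the power-of-2 branch): assuming kM=32,
    // kOHeaddim=64, NumWarps=4 and KVector=4 on a 64-lane warp, OtherK=16, so
    // KThreads=16 lanes cover dim-K, MThreadPerWarp = 64/16 = 4, and each thread
    // owns MPerThread = 32/(4*4) = 2 rows of 4 contiguous Oacc elements.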

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
    {
        return Problem::GetOaccDramTileAccessMaxVectorSize();
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
    {
        // should be the same as GetAlignmentOacc() since o_tile will use the same encoding as
        // o_acc_tile
        return GetAlignmentOacc<Problem>();
    }
};

} // namespace ck_tile
@@ -0,0 +1,34 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "hstu_attention_fwd_type_config.hpp"
#include "hstu_attention_tile_setting_define.hpp"

template <ck_tile::index_t kOHeaddim>
struct HstuAttentionFwdSplitKVCombineTileSetting;

template <>
struct HstuAttentionFwdSplitKVCombineTileSetting<64>
{
    using Type = ck_tile::HstuAttentionFwdSplitKVCombineTileSettingClass<32, 4, 64>;
};

template <>
struct HstuAttentionFwdSplitKVCombineTileSetting<96>
{
    using Type = ck_tile::HstuAttentionFwdSplitKVCombineTileSettingClass<16, 4, 96>;
};

template <>
struct HstuAttentionFwdSplitKVCombineTileSetting<128>
{
    using Type = ck_tile::HstuAttentionFwdSplitKVCombineTileSettingClass<16, 4, 128>;
};

template <>
struct HstuAttentionFwdSplitKVCombineTileSetting<256>
{
    using Type = ck_tile::HstuAttentionFwdSplitKVCombineTileSettingClass<8, 4, 256>;
};
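A quick sanity check of these specializations against the static asserts in HstuAttentionFwdSplitKVCombineTileSettingClass (a sketch; 64-lane warps assumed):

// <kM, NumWarps, kOHeaddim> = <32, 4, 64>: kM % NumWarps == 0 and
// (32 * 64) % (4 * 64) == 0 both hold, giving 2048 / 256 = 8 Oacc elements
// per thread. The wider-hdim entries shrink kM (16, then 8) so the per-block
// element count and per-thread work stay roughly constant as kOHeaddim grows.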
File diff suppressed because it is too large
@@ -22,6 +22,8 @@
#include "hstu_attention_fwd_kernel.hpp"
#include "hstu_attention_epilogue.hpp"

#include "hstu_attention_group_forward_splitkv_dispatch.hpp"

template <typename InOutDataType,
          bool kUseCausal,
          bool kUseSoftmax,
@@ -185,11 +187,36 @@ void run_group_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionGroupFw
                                                            MaxK,
                                                            128>::Run(param, stream);
    else
        group_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
                                                           kUseCausal,
                                                           kUseSoftmax,
                                                           kHasBias,
                                                           kHasDropout,
                                                           MaxK,
                                                           64>::Run(param, stream);
    {
        const bool disable_fwd_splitkv = []() {
            const char* env_p = std::getenv("HSTU_DISABLE_SPLITKV");
            if(env_p == nullptr)
                return false;
            return static_cast<bool>(atoi(env_p));
        }();

        // ToDo: enable splitkv when kUseSoftmax is true
        if(!disable_fwd_splitkv && !kUseSoftmax &&
           shall_use_splitkv(param.num_batch, param.num_head, param.max_seqlen))
        {
            if constexpr(!kUseSoftmax)
            {
                group_forward_splitkv_causal_softmax_bias_dropout_dispatch<InOutDataType,
                                                                           kUseCausal,
                                                                           kUseSoftmax,
                                                                           kHasBias,
                                                                           kHasDropout,
                                                                           MaxK,
                                                                           64>::Run(param, stream);
            };
        }
        else
            group_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
                                                               kUseCausal,
                                                               kUseSoftmax,
                                                               kHasBias,
                                                               kHasDropout,
                                                               MaxK,
                                                               64>::Run(param, stream);
    };
};

@@ -0,0 +1,258 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <ck_tile/core/numeric/integer.hpp>
#include <ck_tile/host/kernel_launch.hpp>
#include <ck_tile/host/stream_config.hpp>
#include <ck_tile/ops/epilogue.hpp>
#include <ck_tile/host/hip_check_error.hpp>

#include "hstu_attention_bool_switch.hpp"
#include "hstu_attention_fwd_type_config.hpp"
#include "hstu_attention_fwd_setting.hpp"
#include "hstu_attention_fwd_splitkv_combine_setting.hpp"
#include "hstu_attention_params.hpp"
#include "hstu_attention_hdim_switch.hpp"
#include "hstu_attention_pipeline_problem.hpp"
#include "hstu_attention_traits.hpp"
#include "hstu_attention_with_softmax_fwd_pipeline.hpp"
#include "hstu_attention_no_softmax_fwd_pipeline.hpp"
#include "hstu_attention_with_softmax_fwd_trload_pipeline.hpp"
#include "hstu_attention_no_softmax_fwd_trload_pipeline.hpp"
#include "hstu_attention_no_softmax_fwd_splitkv_combine_pipeline.hpp"
#include "hstu_attention_fwd_splitkv_kernel.hpp"
#include "hstu_attention_fwd_splitkv_combine_kernel.hpp"

template <typename InOutDataType,
          bool kUseCausal,
          bool kUseSoftmax,
          bool kHasBias,
          bool kHasDropout,
          ck_tile::index_t MaxK,
          ck_tile::index_t MTile>
struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
{
    static_assert(kUseSoftmax == false, "Softmax support is not enabled yet!");
    static_assert(MTile == 64, "MTile must be 64 to get to fwd splitkv path!");

    using HstuAttentionFwdTileSetting =
        typename std::conditional_t<kUseSoftmax,
                                    HstuAttentionWithSoftmaxFwdTileSetting<MaxK, MTile>,
                                    HstuAttentionNoSoftmaxFwdTileSetting<MaxK, MTile>>::Type;
    using HstuAttentionCombineTileSetting =
        typename HstuAttentionFwdSplitKVCombineTileSetting<MaxK>::Type;

#ifdef BUILD_HSTU_FOR_GFX95_ONLY
    static constexpr bool kUseTrLoad = true;
#else
    static constexpr bool kUseTrLoad = false;
#endif

    template <bool kIsCrossAttention>
    using HstuFwdPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
        InOutDataType,
        typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType,
        typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType,
        typename HstuAttentionFwdTypeConfig<InOutDataType>::BiasDataType,
        kIsCrossAttention,
        true, // kUseGroup
        true, // kIsJagged
        kHasBias,
        kHasDropout,
        kUseCausal,
        kUseSoftmax,
        kUseTrLoad,
        HstuAttentionFwdTileSetting>;

    using OaccDataType = HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType;
    using ODataType = HstuAttentionFwdTypeConfig<InOutDataType>::ODataType;

    static void Run(HstuAttentionGroupFwdParams& param, hipStream_t stream)
    {
        constexpr ck_tile::index_t occupancy = -1;

        {
            const bool pad_headdim_qk =
                !(param.hdim_qk % HstuAttentionFwdTileSetting::kQKHeaddim == 0);
            const bool pad_headdim_v = !(param.hdim_v % HstuAttentionFwdTileSetting::kN1 == 0);

            // no need to check seqlen_q since it is not used as fastest dim,
            // buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
            constexpr bool kPadSeqLenQ = false;

            constexpr bool kPadSeqLenK = true;

            BOOL_SWITCH_2(pad_headdim_qk, kPadHeadDimQK, pad_headdim_v, kPadHeadDimV, [&] {
                using HstuTraits = ck_tile::HstuAttentionFwdTraits<kPadSeqLenQ,
                                                                   kPadSeqLenK,
                                                                   kPadHeadDimQK,
                                                                   kPadHeadDimV,
                                                                   occupancy>;

                using HstuEpilogue = ck_tile::Default2DEpilogue<
                    ck_tile::Default2DEpilogueProblem<OaccDataType,
                                                      OaccDataType, // keep output as OaccDataType
                                                      kPadSeqLenQ,
                                                      kPadHeadDimV,
                                                      false>>;

                BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] {
                    using HstuPipelineProblem = HstuFwdPipelineProblemTemp<kIsCrossAttention>;

                    if constexpr(!kUseTrLoad)
                    {
                        using HstuPipeline = std::conditional_t<
                            kUseSoftmax,
                            ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem,
                                                                               HstuTraits>,
                            ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem,
                                                                             HstuTraits>>;

                        using HstuKernel =
                            ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;

                        RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
                    }
                    else
                    {
                        using HstuPipeline = std::conditional_t<
                            kUseSoftmax,
                            ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad<
                                HstuPipelineProblem,
                                HstuTraits>,
                            ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad<
                                HstuPipelineProblem,
                                HstuTraits>>;

                        using HstuKernel =
                            ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;

                        RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
                    };
                });
            });
        };
        {
            using HstuCombinePipelineProblem =
                ck_tile::HstuAttentionFwdSplitKVCombinePipelineProblem<
                    OaccDataType,
                    ODataType,
                    true /* kIsJagged */,
                    kUseSoftmax,
                    HstuAttentionCombineTileSetting>;
            const bool pad_headdim_o =
                !(param.hdim_v % HstuAttentionCombineTileSetting::kOHeaddim == 0);

            // no need to check seqlen_q since it is not used as fastest dim,
            // buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
            constexpr bool kPadSeqLenQ = false;

            BOOL_SWITCH(pad_headdim_o, kPadHeadDimO, [&] {
                using HstuTraits = ck_tile::
                    HstuAttentionFwdSplitKVCombineTraits<kPadSeqLenQ, kPadHeadDimO, occupancy>;

                using HstuEpilogue =
                    ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<OaccDataType,
                                                                                 ODataType,
                                                                                 kPadSeqLenQ,
                                                                                 kPadHeadDimO,
                                                                                 false>>;

                using HstuPipeline = ck_tile::HstuAttentionNoSoftmaxFwdSplitKVCombinePipeline<
                    HstuCombinePipelineProblem,
                    HstuTraits>;

                using HstuKernel =
                    ck_tile::HstuAttentionFwdSplitKVCombineKernel<HstuPipeline, HstuEpilogue>;

                RunWithFwdSplitKVCombineKernel<HstuKernel>(param, stream);
            });
        };
    };

    template <typename HstuKernel>
    static void RunWithFwdSplitKVKernel(HstuAttentionGroupFwdParams& param, hipStream_t stream)
    {
        param.num_splits =
            get_suggested_num_splits(param.num_batch, param.num_head, param.max_seqlen);

        // assume the workspace for o_acc is in compact shape of [num_batch, max_seqlen, num_head,
        // num_splits, hdim]
        size_t workspace_bytes = static_cast<size_t>(param.num_batch) * param.max_seqlen *
                                 param.num_head * param.num_splits * param.hdim_v *
                                 sizeof(OaccDataType);

        HIP_CHECK_ERROR(hipMallocAsync(&param.o_acc_ptr, workspace_bytes, stream));

        const auto kargs = [&] {
            return HstuKernel::MakeKargs(param.q_ptr,
                                         param.k_ptr,
                                         param.v_ptr,
                                         param.bias_ptr,
                                         param.o_acc_ptr,
                                         param.num_splits,
                                         param.num_batch / param.num_group,
                                         param.seq_q_offsets_ptr,
                                         param.is_cross_attention ? param.seq_kv_offsets_ptr
                                                                  : param.seq_q_offsets_ptr,
                                         param.group_max_seqlen_ptr,
                                         param.group_contextual_seqlen_ptr,
                                         param.group_window_size_ptr,
                                         param.group_min_full_attn_seqlen_ptr,
                                         param.group_attn_scale_ptr,
                                         param.hdim_qk,
                                         param.hdim_v,
                                         param.num_head,
                                         param.scale_s,
                                         param.seq_stride_q,
                                         param.seq_stride_k,
                                         param.seq_stride_v,
                                         param.seq_stride_bias,
                                         param.nhead_stride_q,
                                         param.nhead_stride_k,
                                         param.nhead_stride_v,
                                         param.nhead_stride_bias,
                                         param.num_targets_ptr,
                                         param.p_drop,
                                         param.philox_seed,
                                         param.philox_offset);
        }();

        dim3 kGridSize = HstuKernel::GridSize(
            param.num_batch, param.num_head, param.max_seqlen, param.hdim_v, param.num_splits);
        constexpr dim3 kBlockSize = HstuKernel::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;

        (void)ck_tile::launch_kernel(
            ck_tile::stream_config{stream, false},
            ck_tile::make_kernel<kBlockPerCu>(HstuKernel{}, kGridSize, kBlockSize, 0, kargs));
    };

    template <typename HstuKernel>
    static void RunWithFwdSplitKVCombineKernel(HstuAttentionGroupFwdParams& param,
                                               hipStream_t stream)
    {
        const auto kargs = [&] {
            return HstuKernel::MakeKargs(param.o_acc_ptr,
                                         param.o_ptr,
                                         param.seq_stride_o,
                                         param.nhead_stride_o,
                                         param.seq_q_offsets_ptr,
                                         param.num_head,
                                         param.num_splits,
                                         param.hdim_v);
        }();

        dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen);
        constexpr dim3 kBlockSize = HstuKernel::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;

        (void)ck_tile::launch_kernel(
            ck_tile::stream_config{stream, false},
            ck_tile::make_kernel<kBlockPerCu>(HstuKernel{}, kGridSize, kBlockSize, 0, kargs));

        HIP_CHECK_ERROR(hipFreeAsync(param.o_acc_ptr, stream));
    };
};
@@ -22,6 +22,8 @@
#include "hstu_attention_fwd_kernel.hpp"
#include "hstu_attention_epilogue.hpp"

#include "hstu_attention_jagged_forward_splitkv_dispatch.hpp"

template <typename InOutDataType,
          bool kUseCausal,
          bool kUseSoftmax,
@@ -188,11 +190,36 @@ void run_jagged_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionNoGrou
                                                             MaxK,
                                                             128>::Run(param, stream);
    else
        jagged_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
                                                            kUseCausal,
                                                            kUseSoftmax,
                                                            kHasBias,
                                                            kHasDropout,
                                                            MaxK,
                                                            64>::Run(param, stream);
    {
        const bool disable_fwd_splitkv = []() {
            const char* env_p = std::getenv("HSTU_DISABLE_SPLITKV");
            if(env_p == nullptr)
                return false;
            return static_cast<bool>(atoi(env_p));
        }();

        // ToDo: enable splitkv when kUseSoftmax is true
        if(!disable_fwd_splitkv && !kUseSoftmax &&
           shall_use_splitkv(param.num_batch, param.num_head, param.max_seqlen))
        {
            if constexpr(!kUseSoftmax)
            {
                jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch<InOutDataType,
                                                                            kUseCausal,
                                                                            kUseSoftmax,
                                                                            kHasBias,
                                                                            kHasDropout,
                                                                            MaxK,
                                                                            64>::Run(param, stream);
            };
        }
        else
            jagged_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
                                                                kUseCausal,
                                                                kUseSoftmax,
                                                                kHasBias,
                                                                kHasDropout,
                                                                MaxK,
                                                                64>::Run(param, stream);
    };
};

@@ -0,0 +1,261 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <ck_tile/core/numeric/integer.hpp>
#include <ck_tile/host/kernel_launch.hpp>
#include <ck_tile/host/stream_config.hpp>
#include <ck_tile/ops/epilogue.hpp>
#include <ck_tile/host/hip_check_error.hpp>

#include "hstu_attention_bool_switch.hpp"
#include "hstu_attention_fwd_type_config.hpp"
#include "hstu_attention_fwd_setting.hpp"
#include "hstu_attention_fwd_splitkv_combine_setting.hpp"
#include "hstu_attention_params.hpp"
#include "hstu_attention_hdim_switch.hpp"
#include "hstu_attention_pipeline_problem.hpp"
#include "hstu_attention_traits.hpp"
#include "hstu_attention_with_softmax_fwd_pipeline.hpp"
#include "hstu_attention_no_softmax_fwd_pipeline.hpp"
#include "hstu_attention_with_softmax_fwd_trload_pipeline.hpp"
#include "hstu_attention_no_softmax_fwd_trload_pipeline.hpp"
#include "hstu_attention_no_softmax_fwd_splitkv_combine_pipeline.hpp"
#include "hstu_attention_fwd_splitkv_kernel.hpp"
#include "hstu_attention_fwd_splitkv_combine_kernel.hpp"

template <typename InOutDataType,
          bool kUseCausal,
          bool kUseSoftmax,
          bool kHasBias,
          bool kHasDropout,
          ck_tile::index_t MaxK,
          ck_tile::index_t MTile>
struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
{
    static_assert(kUseSoftmax == false, "Softmax support is not enabled yet!");
    static_assert(MTile == 64, "MTile must be 64 to get to fwd splitkv path!");

    using HstuAttentionFwdTileSetting =
        typename std::conditional_t<kUseSoftmax,
                                    HstuAttentionWithSoftmaxFwdTileSetting<MaxK, MTile>,
                                    HstuAttentionNoSoftmaxFwdTileSetting<MaxK, MTile>>::Type;
    using HstuAttentionCombineTileSetting = HstuAttentionFwdSplitKVCombineTileSetting<MaxK>::Type;

#ifdef BUILD_HSTU_FOR_GFX95_ONLY
    static constexpr bool kUseTrLoad = true;
#else
    static constexpr bool kUseTrLoad = false;
#endif

    template <bool kIsCrossAttention>
    using HstuFwdPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
        InOutDataType,
        typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType,
        typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType,
        typename HstuAttentionFwdTypeConfig<InOutDataType>::BiasDataType,
        kIsCrossAttention,
        false, // kUseGroup
        true, // kIsJagged
        kHasBias,
        kHasDropout,
        kUseCausal,
        kUseSoftmax,
        kUseTrLoad,
        HstuAttentionFwdTileSetting>;

    using OaccDataType = HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType;
    using ODataType = HstuAttentionFwdTypeConfig<InOutDataType>::ODataType;

    static void Run(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
    {
        constexpr ck_tile::index_t occupancy = -1;

        {
            const bool pad_headdim_qk =
                !(param.hdim_qk % HstuAttentionFwdTileSetting::kQKHeaddim == 0);
            const bool pad_headdim_v = !(param.hdim_v % HstuAttentionFwdTileSetting::kN1 == 0);

            // no need to check seqlen_q since it is not used as fastest dim,
            // buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
            constexpr bool kPadSeqLenQ = false;

            constexpr bool kPadSeqLenK = true;

            BOOL_SWITCH_2(pad_headdim_qk, kPadHeadDimQK, pad_headdim_v, kPadHeadDimV, [&] {
                using HstuTraits = ck_tile::HstuAttentionFwdTraits<kPadSeqLenQ,
                                                                   kPadSeqLenK,
                                                                   kPadHeadDimQK,
                                                                   kPadHeadDimV,
                                                                   occupancy>;

                using HstuEpilogue = ck_tile::Default2DEpilogue<
                    ck_tile::Default2DEpilogueProblem<OaccDataType,
                                                      OaccDataType, // keep output as OaccDataType
                                                      kPadSeqLenQ,
                                                      kPadHeadDimV,
                                                      false>>;

                BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] {
                    using HstuPipelineProblem = HstuFwdPipelineProblemTemp<kIsCrossAttention>;

                    if constexpr(!kUseTrLoad)
                    {
                        using HstuPipeline = std::conditional_t<
                            kUseSoftmax,
                            ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem,
                                                                               HstuTraits>,
                            ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem,
                                                                             HstuTraits>>;

                        using HstuKernel =
                            ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;

                        RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
                    }
                    else
                    {
                        using HstuPipeline = std::conditional_t<
                            kUseSoftmax,
                            ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad<
                                HstuPipelineProblem,
                                HstuTraits>,
                            ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad<
                                HstuPipelineProblem,
                                HstuTraits>>;

                        using HstuKernel =
                            ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;

                        RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
                    };
                });
            });
        };
        {
            using HstuCombinePipelineProblem =
                ck_tile::HstuAttentionFwdSplitKVCombinePipelineProblem<
                    OaccDataType,
                    ODataType,
                    true /* kIsJagged */,
                    kUseSoftmax,
                    HstuAttentionCombineTileSetting>;
            const bool pad_headdim_o =
                !(param.hdim_v % HstuAttentionCombineTileSetting::kOHeaddim == 0);

            // no need to check seqlen_q since it is not used as fastest dim,
            // buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
            constexpr bool kPadSeqLenQ = false;

            BOOL_SWITCH(pad_headdim_o, kPadHeadDimO, [&] {
                using HstuTraits = ck_tile::
                    HstuAttentionFwdSplitKVCombineTraits<kPadSeqLenQ, kPadHeadDimO, occupancy>;

                using HstuEpilogue =
                    ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<OaccDataType,
                                                                                 ODataType,
                                                                                 kPadSeqLenQ,
                                                                                 kPadHeadDimO,
                                                                                 false>>;

                using HstuPipeline = ck_tile::HstuAttentionNoSoftmaxFwdSplitKVCombinePipeline<
                    HstuCombinePipelineProblem,
                    HstuTraits>;

                using HstuKernel =
                    ck_tile::HstuAttentionFwdSplitKVCombineKernel<HstuPipeline, HstuEpilogue>;

                RunWithFwdSplitKVCombineKernel<HstuKernel>(param, stream);
            });
        };
    };

    template <typename HstuKernel>
    static void RunWithFwdSplitKVKernel(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
    {
        param.num_splits =
            get_suggested_num_splits(param.num_batch, param.num_head, param.max_seqlen);

        // assume the workspace for o_acc is in compact shape of [num_batch, max_seqlen, num_head,
        // num_splits, hdim]
        size_t workspace_bytes = static_cast<size_t>(param.num_batch) * param.max_seqlen *
                                 param.num_head * param.num_splits * param.hdim_v *
                                 sizeof(OaccDataType);

        HIP_CHECK_ERROR(hipMallocAsync(&param.o_acc_ptr, workspace_bytes, stream));

        const auto kargs = [&] {
            return HstuKernel::MakeKargs(param.q_ptr,
                                         param.k_ptr,
                                         param.v_ptr,
                                         param.bias_ptr,
                                         param.o_acc_ptr,
                                         param.num_splits,
                                         param.seq_q_offsets_ptr,
                                         param.is_cross_attention ? param.seq_kv_offsets_ptr
                                                                  : param.seq_q_offsets_ptr,
                                         param.max_seqlen,
                                         param.hdim_qk,
                                         param.hdim_v,
                                         param.num_head,
                                         param.scale_s,
                                         param.attn_scale,
                                         param.seq_stride_q,
                                         param.seq_stride_k,
                                         param.seq_stride_v,
                                         param.seq_stride_bias,
                                         param.nhead_stride_q,
                                         param.nhead_stride_k,
                                         param.nhead_stride_v,
                                         param.nhead_stride_bias,
                                         param.num_targets_ptr,
                                         param.contextual_seqlen,
                                         param.window_size,
                                         param.min_full_attn_seqlen,
                                         param.p_drop,
                                         param.philox_seed,
                                         param.philox_offset);
        }();

        bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0);
        dim3 kGridSize = HstuKernel::GridSize(param.num_batch,
                                              param.num_head,
                                              param.max_seqlen,
                                              param.hdim_v,
                                              param.num_splits,
                                              has_minfull_attn_seqlen);
        constexpr dim3 kBlockSize = HstuKernel::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;

        (void)ck_tile::launch_kernel(
            ck_tile::stream_config{stream, false},
            ck_tile::make_kernel<kBlockPerCu>(HstuKernel{}, kGridSize, kBlockSize, 0, kargs));
    };

    template <typename HstuKernel>
    static void RunWithFwdSplitKVCombineKernel(HstuAttentionNoGroupFwdParams& param,
                                               hipStream_t stream)
    {
        const auto kargs = [&] {
            return HstuKernel::MakeKargs(param.o_acc_ptr,
                                         param.o_ptr,
                                         param.seq_stride_o,
                                         param.nhead_stride_o,
                                         param.seq_q_offsets_ptr,
                                         param.num_head,
                                         param.num_splits,
                                         param.hdim_v);
        }();

        dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen);
        constexpr dim3 kBlockSize = HstuKernel::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;

        (void)ck_tile::launch_kernel(
            ck_tile::stream_config{stream, false},
            ck_tile::make_kernel<kBlockPerCu>(HstuKernel{}, kGridSize, kBlockSize, 0, kargs));

        HIP_CHECK_ERROR(hipFreeAsync(param.o_acc_ptr, stream));
    };
};
@@ -180,6 +180,14 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
    using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
    OaccBlockTileType o_acc;

    if(seqlen_k_end <= seqlen_k_start)
    {
        clear_tile(o_acc);
        o_acc = tile_elementwise_in(o_acc_element_func, o_acc);

        return o_acc;
    };

    auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
                                          make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
                                          q_dram_block_window_tmp.get_window_origin(),

@@ -0,0 +1,103 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <ck_tile/core.hpp>

#include "hstu_attention_fwd_splitkv_combine_pipeline_policy.hpp"

namespace ck_tile {

template <typename Problem_,
          typename Traits_,
          typename Policy_ = HstuAttentionFwdSplitKVCombinePipelinePolicy>
struct HstuAttentionNoSoftmaxFwdSplitKVCombinePipeline
{
    using Problem = remove_cvref_t<Problem_>;
    using Traits = remove_cvref_t<Traits_>;
    using Policy = remove_cvref_t<Policy_>;
    using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
    using ODataType = remove_cvref_t<typename Problem::ODataType>;

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    static constexpr index_t kM = Problem::kM;
    static constexpr index_t kOHeaddim = Problem::kOHeaddim;

    static_assert(kOHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");

    static_assert(Problem::kUseSoftmax == false, "This pipeline only works when softmax is not used");

    static constexpr bool kIsJagged = Problem::kIsJagged;

    static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
    static constexpr bool kPadHeadDimO = Traits::kPadHeadDimO;

    // last-dimension vector length used to create the tensor view (and to decide the buffer_load
    // vector length) together with the tensor distribution; the tensor distribution should be
    // able to overwrite this
    static constexpr index_t kAlignmentO =
        kPadHeadDimO ? 1 : Policy::template GetAlignmentO<Problem>();

    static constexpr index_t kBlockPerCu = []() {
        if constexpr(Traits::kBlockPerCu != -1)
            return Traits::kBlockPerCu;
        else
        {
            return 2;
        }
    }();

    static constexpr const char* name = "hstu_no_softmax_fwd_splitkv_combine";

    CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }

    template <typename OAccDramBlockWindowTmp, typename OAccElementFunction>
    CK_TILE_DEVICE auto
    operator()(const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // kM*kOHeaddim tile
               const OAccElementFunction& o_acc_element_func,
               ck_tile::index_t split_stride,
               ck_tile::index_t num_splits) const
    {
        static_assert(
            std::is_same_v<OaccDataType, remove_cvref_t<typename OAccDramBlockWindowTmp::DataType>>,
            "wrong!");

        static_assert(kM == OAccDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kOHeaddim == OAccDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                      "wrong!");

        auto o_acc_dram_window =
            make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(),
                             make_tuple(number<kM>{}, number<kOHeaddim>{}),
                             o_acc_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeOaccDramTileDistribution<Problem>());

        auto o_acc_ptr = o_acc_dram_window.get_bottom_tensor_view().get_buffer_view().p_data_;

        auto o_acc = load_tile(o_acc_dram_window);

        for(int i = 1; i < num_splits; i++)
        {
            o_acc_dram_window.set_bottom_tensor_view_data_ptr(o_acc_ptr + split_stride * i);
            auto o_acc_tile = load_tile(o_acc_dram_window);

            tile_elementwise_inout([](auto& x, const auto& y) { x = x + y; }, o_acc, o_acc_tile);
        };

        o_acc = tile_elementwise_in(o_acc_element_func, o_acc);

        return o_acc;
    }
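    // Minimal scalar analogue of the split reduction above (illustrative only;
    // combine_one is a hypothetical helper, not part of this change): each split
    // wrote a partial o_acc a fixed split_stride apart, and the combine step sums
    // them into one value per output element.
    //
    // float combine_one(const float* o_acc, int num_splits, long split_stride, long idx)
    // {
    //     float o = o_acc[idx];                   // contribution of split 0
    //     for(int i = 1; i < num_splits; ++i)
    //         o += o_acc[idx + i * split_stride]; // splits 1..num_splits-1
    //     return o;
    // }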

    template <typename OAccDramBlockWindowTmp>
    CK_TILE_DEVICE auto
    operator()(const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // kM*kOHeaddim tile
               ck_tile::index_t split_stride,
               ck_tile::index_t num_splits) const
    {
        return operator()(o_acc_dram_block_window_tmp, identity{}, split_stride, num_splits);
    };
};

} // namespace ck_tile
@@ -180,6 +180,14 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
    using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
    OaccBlockTileType o_acc;

    if(seqlen_k_end <= seqlen_k_start)
    {
        clear_tile(o_acc);
        o_acc = tile_elementwise_in(o_acc_element_func, o_acc);

        return o_acc;
    };

    auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
                                          make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
                                          q_dram_block_window_tmp.get_window_origin(),

@@ -65,6 +65,13 @@ struct HstuAttentionNoGroupFwdParams
    float p_drop;
    uint64_t philox_seed;
    uint64_t philox_offset;

    // this need not be set by API users; it is only used to pass num_splits between the splitkv
    // and combine kernels
    int num_splits;
    // pointer to device memory allocated before calling the fwd_splitkv kernel and released after
    // calling the combine kernel
    void* o_acc_ptr;
};
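Taken together, the two new fields carry state across the two launches. A minimal sketch of the intended host flow (function names as in the diff; the surrounding setup is assumed):

// HstuAttentionNoGroupFwdParams param = /* user-filled tensor pointers and strides */;
// RunWithFwdSplitKVKernel: picks param.num_splits, hipMallocAsync's param.o_acc_ptr,
// and launches the partial-attention kernel that writes per-split results.
// RunWithFwdSplitKVCombineKernel: launches the reduction into param.o_ptr, then
// hipFreeAsync's param.o_acc_ptr on the same stream, so API users never touch
// num_splits or o_acc_ptr directly.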

struct HstuAttentionGroupFwdParams
@@ -125,4 +132,11 @@ struct HstuAttentionGroupFwdParams
    float p_drop;
    uint64_t philox_seed;
    uint64_t philox_offset;

    // this need not be set by API users; it is only used to pass num_splits between the splitkv
    // and combine kernels
    int num_splits;
    // pointer to device memory allocated before calling the fwd_splitkv kernel and released after
    // calling the combine kernel
    void* o_acc_ptr;
};

@@ -133,4 +133,42 @@ struct HstuAttentionFwdPipelineProblem
    };
};

template <typename OaccDataType_,
          typename ODataType_,
          bool kIsJagged_,
          bool kUseSoftmax_,
          typename CombineTileSetting_>
struct HstuAttentionFwdSplitKVCombinePipelineProblem
{
    using OaccDataType = remove_cvref_t<OaccDataType_>;
    using ODataType = remove_cvref_t<ODataType_>;

    static constexpr bool kIsJagged = kIsJagged_;
    static constexpr bool kUseSoftmax = kUseSoftmax_;

    static constexpr index_t kM = CombineTileSetting_::kM;
    static constexpr index_t NumWarps = CombineTileSetting_::NumWarps;
    static constexpr index_t kOHeaddim = CombineTileSetting_::kOHeaddim;
    static constexpr index_t kSubOHeaddim = CombineTileSetting_::kSubOHeaddim;
    static constexpr index_t kBlockSize = CombineTileSetting_::NumWarps * get_warp_size();

    CK_TILE_HOST_DEVICE static constexpr auto GetOaccDramTileAccessMaxVectorSize()
    {
        constexpr index_t kMPerBlock = kM;
        constexpr index_t kKPerBlock = kOHeaddim;

        return detail::
            GetDramTileAccessMaxVectorSize<OaccDataType, kBlockSize, kMPerBlock, kKPerBlock>();
    };

    CK_TILE_HOST_DEVICE static constexpr auto GetODramTileAccessMaxVectorSize()
    {
        constexpr index_t kMPerBlock = kM;
        constexpr index_t kKPerBlock = kOHeaddim;

        return detail::
            GetDramTileAccessMaxVectorSize<ODataType, kBlockSize, kMPerBlock, kKPerBlock>();
    };
};

} // namespace ck_tile

@@ -64,4 +64,21 @@ struct HstuAttentionFwdTileSettingClass
    static_assert(kSubQKHeaddim % kN1 == 0, "Check failed!");
};

template <index_t kM_, // tile size in seqlen_q dimension
          index_t NumWarps_, // assume all warps are assigned to seqlen_q dimension
          index_t kOHeaddim_>
struct HstuAttentionFwdSplitKVCombineTileSettingClass
{
    static constexpr index_t kM = kM_;
    static constexpr index_t NumWarps = NumWarps_;

    static_assert(kM % NumWarps == 0, "Check failed!");

    static constexpr index_t kOHeaddim = kOHeaddim_;

    static_assert((kM * kOHeaddim) % (NumWarps * get_warp_size()) == 0, "Check failed!");

    static constexpr index_t kSubOHeaddim = ceil_to_qualified_tile_length(kOHeaddim);
};

} // namespace ck_tile

@@ -22,4 +22,15 @@ struct HstuAttentionFwdTraits
    static constexpr index_t kBlockPerCu = kBlockPerCu_;
};

template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
          bool kPadHeadDimO_ /* padding for hdim_o */,
          index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct HstuAttentionFwdSplitKVCombineTraits
{
    static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
    static constexpr bool kPadHeadDimO = kPadHeadDimO_;

    static constexpr index_t kBlockPerCu = kBlockPerCu_;
};

} // namespace ck_tile