mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
Add support of loading QK tiles of hdim96 without padding to hdim128
This commit is contained in:
@@ -59,7 +59,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
|
||||
constexpr ck_tile::index_t occupancy = -1;
|
||||
|
||||
const bool pad_seqlen_k = !(param.seqlen_kv % HstuAttentionTileSetting::kN0 == 0);
|
||||
const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionTileSetting::kSubQKHeaddim == 0);
|
||||
const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionTileSetting::kQKHeaddim == 0);
|
||||
const bool pad_headdim_v = !(param.hdim_v % HstuAttentionTileSetting::kN1 == 0);
|
||||
|
||||
// no need to check seqlen_q since it is not used as fastest dim,
|
||||
|
||||
@@ -573,7 +573,7 @@ struct HstuAttentionFwdKernel
|
||||
number<1>{});
|
||||
return pad_tensor_view(q_dram_naive,
|
||||
make_tuple(number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kSubQKHeaddim>{}),
|
||||
number<HstuAttentionPipeline::kQKHeaddim>{}),
|
||||
sequence<false, kPadHeadDimQK>{});
|
||||
}();
|
||||
const auto k_dram = [&]() {
|
||||
@@ -586,7 +586,7 @@ struct HstuAttentionFwdKernel
|
||||
|
||||
return pad_tensor_view(k_dram_naive,
|
||||
make_tuple(number<HstuAttentionPipeline::kN0>{},
|
||||
number<HstuAttentionPipeline::kSubQKHeaddim>{}),
|
||||
number<HstuAttentionPipeline::kQKHeaddim>{}),
|
||||
sequence<false, kPadHeadDimQK>{});
|
||||
}();
|
||||
const auto v_dram = [&]() {
|
||||
@@ -624,14 +624,14 @@ struct HstuAttentionFwdKernel
|
||||
make_tile_window(q_dram,
|
||||
[&]() {
|
||||
return make_tuple(number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kSubQKHeaddim>{});
|
||||
number<HstuAttentionPipeline::kQKHeaddim>{});
|
||||
}(),
|
||||
{i_m0, 0});
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram,
|
||||
make_tuple(number<HstuAttentionPipeline::kN0>{},
|
||||
number<HstuAttentionPipeline::kSubQKHeaddim>{}),
|
||||
number<HstuAttentionPipeline::kQKHeaddim>{}),
|
||||
{0, 0});
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
|
||||
@@ -126,7 +126,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
|
||||
|
||||
return Problem::template GetDramTileAccessMaxVectorSize<QDataType,
|
||||
kBlockSize,
|
||||
@@ -193,11 +193,16 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN0Sub;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentK<Problem>();
|
||||
|
||||
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
// for hdim96 and hdim160
|
||||
if constexpr(kKPerBlock < Problem::HstuAttentionTileSetting::kSubQKHeaddim)
|
||||
{
|
||||
return kKPerBlock * kNPerBlock;
|
||||
}
|
||||
else if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
{
|
||||
static_assert(kKVector == kKPack);
|
||||
|
||||
@@ -244,11 +249,20 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
constexpr index_t kMPerBlock = Problem::kLoadWholeQTileOnceThroughLds
|
||||
? Problem::HstuAttentionTileSetting::kM0
|
||||
: GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentQ<Problem>();
|
||||
|
||||
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
// for hdim96 and hdim160, use simplest layout
|
||||
if constexpr(kKPerBlock < Problem::HstuAttentionTileSetting::kSubQKHeaddim)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<1>{}),
|
||||
number<kKVector>{},
|
||||
number<1>{});
|
||||
}
|
||||
else if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
{
|
||||
static_assert(kKVector == kKPack);
|
||||
|
||||
@@ -331,25 +345,56 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
|
||||
|
||||
constexpr index_t kKVector = GetAlignmentQ<Problem>();
|
||||
constexpr index_t OtherK = kKPerBlock / kKVector;
|
||||
|
||||
constexpr index_t KPerThread = kKVector;
|
||||
constexpr index_t KThreads = kKPerBlock / KPerThread;
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
if constexpr(kKPerBlock == Problem::HstuAttentionTileSetting::kSubQKHeaddim)
|
||||
// for kKPerBlock=32,64,128,256
|
||||
{
|
||||
static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!");
|
||||
|
||||
// for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E]
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
constexpr index_t KPerThread = kKVector;
|
||||
constexpr index_t KThreads = OtherK;
|
||||
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
|
||||
// for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E]
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
else // for kKPerBlock=96,160
|
||||
{
|
||||
static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!");
|
||||
|
||||
// ToDo: need more considieration for hdim72
|
||||
constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5;
|
||||
constexpr index_t KThreads = OtherK / KRepPerThread;
|
||||
|
||||
static_assert((KThreads & (KThreads - 1)) == 0, "Check failed!");
|
||||
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
|
||||
sequence<KRepPerThread, KThreads, kKVector>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<1, 0, 2>>{});
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -357,25 +402,56 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::HstuAttentionTileSetting::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
|
||||
|
||||
constexpr index_t kKVector = GetAlignmentQ<Problem>();
|
||||
constexpr index_t OtherK = kKPerBlock / kKVector;
|
||||
|
||||
constexpr index_t KPerThread = kKVector;
|
||||
constexpr index_t KThreads = kKPerBlock / KPerThread;
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
if constexpr(kKPerBlock == Problem::HstuAttentionTileSetting::kSubQKHeaddim)
|
||||
// for kKPerBlock=32,64,128,256
|
||||
{
|
||||
static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!");
|
||||
|
||||
// for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E]
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
constexpr index_t KPerThread = kKVector;
|
||||
constexpr index_t KThreads = OtherK;
|
||||
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
|
||||
// for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E]
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
else // for kKPerBlock=96,160
|
||||
{
|
||||
static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!");
|
||||
|
||||
// ToDo: need more considieration for hdim72
|
||||
constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5;
|
||||
constexpr index_t KThreads = OtherK / KRepPerThread;
|
||||
|
||||
static_assert((KThreads & (KThreads - 1)) == 0, "Check failed!");
|
||||
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
|
||||
sequence<KRepPerThread, KThreads, kKVector>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<1, 0, 2>>{});
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -383,11 +459,36 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
{
|
||||
constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers<Problem>();
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN0Sub;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentK<Problem>();
|
||||
|
||||
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
// for hdim96 and hdim160, use simplest layout
|
||||
if constexpr(kKPerBlock < Problem::HstuAttentionTileSetting::kSubQKHeaddim)
|
||||
{
|
||||
constexpr index_t KSingleSmemElementSpaceSize = kNPerBlock * kKPerBlock;
|
||||
|
||||
static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize<Problem>());
|
||||
|
||||
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
|
||||
|
||||
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<NumKLdsBuffers>{}, number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
make_tuple(number<SingleSmemElementSpaceSize>{}, number<kKPerBlock>{}, number<1>{}),
|
||||
number<kKVector>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
|
||||
k_lds_block_desc_0,
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(number<NumKLdsBuffers>{}, number<kNPerBlock>{})),
|
||||
make_pass_through_transform(number<kKPerBlock>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return k_lds_block_desc;
|
||||
}
|
||||
else if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
{
|
||||
static_assert(kKVector == kKPack);
|
||||
|
||||
@@ -500,24 +601,52 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN0Sub;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
|
||||
|
||||
constexpr index_t kKVector = GetAlignmentK<Problem>();
|
||||
constexpr index_t OtherK = kKPerBlock / kKVector;
|
||||
|
||||
constexpr index_t KPerThread = kKVector;
|
||||
constexpr index_t KThreads = kKPerBlock / KPerThread;
|
||||
constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps);
|
||||
if constexpr(kKPerBlock == Problem::HstuAttentionTileSetting::kSubQKHeaddim)
|
||||
// for kKPerBlock=32,64,128,256
|
||||
{
|
||||
static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
constexpr index_t KPerThread = kKVector;
|
||||
constexpr index_t KThreads = OtherK;
|
||||
|
||||
constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
else // for kKPerBlock=96,160
|
||||
{
|
||||
static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!");
|
||||
|
||||
constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5;
|
||||
constexpr index_t KThreads = OtherK / KRepPerThread;
|
||||
|
||||
constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
|
||||
sequence<KRepPerThread, KThreads, kKVector>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -58,7 +58,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
|
||||
{
|
||||
constexpr ck_tile::index_t occupancy = -1;
|
||||
|
||||
const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionTileSetting::kSubQKHeaddim == 0);
|
||||
const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionTileSetting::kQKHeaddim == 0);
|
||||
const bool pad_headdim_v = !(param.hdim_v % HstuAttentionTileSetting::kN1 == 0);
|
||||
|
||||
// no need to check seqlen_q since it is not used as fastest dim,
|
||||
|
||||
@@ -37,7 +37,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
static_assert(kQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
static_assert(Problem::kUseSoftmax == false, "This pipeline only works with not-using softmax");
|
||||
|
||||
@@ -53,7 +53,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQK = Traits::kPadHeadDimQK;
|
||||
static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? true : Traits::kPadHeadDimV;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
@@ -127,9 +127,9 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
typename OAccElementFunction,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
@@ -151,8 +151,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kSubQKHeaddim ==
|
||||
KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
@@ -184,7 +183,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
auto q_dram_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kSubQKHeaddim>{}),
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQDramSingleRepMTileDistribution<Problem>());
|
||||
|
||||
@@ -194,7 +193,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0Sub>{}, number<kSubQKHeaddim>{}),
|
||||
make_tuple(number<kN0Sub>{}, number<kQKHeaddim>{}),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
@@ -226,7 +225,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_write_window = make_tile_window(
|
||||
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds,
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
|
||||
@@ -241,25 +239,15 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
auto k_lds_window = make_tile_window(
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
using k_lds_write_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kSubQKHeaddim>{}));
|
||||
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
using k_lds_read_window_type = decltype(get_slice_tile(
|
||||
using k_lds_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
|
||||
|
||||
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
|
||||
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
|
||||
statically_indexed_array<k_lds_window_type, NumKVLdsBuffers> k_lds_windows;
|
||||
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
k_lds_write_windows[i_buf] =
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{});
|
||||
k_lds_read_windows[i_buf] =
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
|
||||
k_lds_windows[i_buf] = get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
|
||||
});
|
||||
|
||||
// V tile in LDS
|
||||
@@ -382,7 +370,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
{
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_write_windows[i_n0], k_tiles[i_n0], partition_index);
|
||||
store_tile(k_lds_windows[i_n0], k_tiles[i_n0], partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -395,7 +383,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
block_sync_lds();
|
||||
|
||||
// execute current unroll of gemm_0
|
||||
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
|
||||
@@ -535,7 +523,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
|
||||
@@ -55,7 +55,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQK = Traits::kPadHeadDimQK;
|
||||
static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? true : Traits::kPadHeadDimV;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
@@ -126,9 +126,9 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
typename OAccElementFunction,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
@@ -150,8 +150,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kSubQKHeaddim ==
|
||||
KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
@@ -183,7 +182,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
auto q_dram_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kM0>{}, number<kSubQKHeaddim>{}),
|
||||
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQDramTileDistribution<Problem>());
|
||||
|
||||
@@ -193,7 +192,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0Sub>{}, number<kSubQKHeaddim>{}),
|
||||
make_tuple(number<kN0Sub>{}, number<kQKHeaddim>{}),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
@@ -219,7 +218,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_write_window = make_tile_window(
|
||||
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds,
|
||||
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
|
||||
@@ -234,25 +232,15 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
auto k_lds_window = make_tile_window(
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
using k_lds_write_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kSubQKHeaddim>{}));
|
||||
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
using k_lds_read_window_type = decltype(get_slice_tile(
|
||||
using k_lds_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
|
||||
|
||||
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
|
||||
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
|
||||
statically_indexed_array<k_lds_window_type, NumKVLdsBuffers> k_lds_windows;
|
||||
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
k_lds_write_windows[i_buf] =
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{});
|
||||
k_lds_read_windows[i_buf] =
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
|
||||
k_lds_windows[i_buf] = get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
|
||||
});
|
||||
|
||||
// V tile in LDS
|
||||
@@ -351,7 +339,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
{
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
k_tiles[i_n0],
|
||||
partition_index);
|
||||
|
||||
@@ -366,7 +354,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
block_sync_lds();
|
||||
|
||||
// execute current unroll of gemm_0
|
||||
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
|
||||
@@ -481,8 +469,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
HstuMask mask,
|
||||
|
||||
@@ -100,7 +100,7 @@ struct HstuAttentionFwdPipelineProblem
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetQDramTileAccessMaxVectorSize()
|
||||
{
|
||||
constexpr index_t kMPerBlock = HstuAttentionTileSetting::kM0;
|
||||
constexpr index_t kKPerBlock = HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
constexpr index_t kKPerBlock = HstuAttentionTileSetting::kQKHeaddim;
|
||||
|
||||
return GetDramTileAccessMaxVectorSize<QKVDataType, kBlockSize, kMPerBlock, kKPerBlock>();
|
||||
}
|
||||
@@ -108,7 +108,7 @@ struct HstuAttentionFwdPipelineProblem
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKDramTileAccessMaxVectorSize()
|
||||
{
|
||||
constexpr index_t kNPerBlock = HstuAttentionTileSetting::kN0Sub;
|
||||
constexpr index_t kKPerBlock = HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
constexpr index_t kKPerBlock = HstuAttentionTileSetting::kQKHeaddim;
|
||||
|
||||
return GetDramTileAccessMaxVectorSize<QKVDataType, kBlockSize, kNPerBlock, kKPerBlock>();
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ struct HstuAttentionFwdTileSettingClass
|
||||
|
||||
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
|
||||
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
|
||||
static constexpr index_t kN0Sub = BlockTile::at(number<2>{}); // tile size along k seqlen
|
||||
static constexpr index_t kN0Sub = BlockTile::at(number<2>{}); // tile size for dividing kN0
|
||||
static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
|
||||
static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
|
||||
static constexpr index_t kQKHeaddim =
|
||||
@@ -60,6 +60,8 @@ struct HstuAttentionFwdTileSettingClass
|
||||
// once (or repeately load Q as a whole tile)
|
||||
|
||||
static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim);
|
||||
|
||||
static_assert(kSubQKHeaddim % kN1 == 0, "Check failed!");
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -37,7 +37,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
static_assert(kQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
static_assert(Problem::kUseSoftmax == true, "This pipeline only works with using softmax");
|
||||
|
||||
@@ -53,7 +53,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQK = Traits::kPadHeadDimQK;
|
||||
static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? true : Traits::kPadHeadDimV;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
@@ -127,9 +127,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
typename OAccElementFunction,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
@@ -153,8 +153,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kSubQKHeaddim ==
|
||||
KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
@@ -198,7 +197,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
auto q_dram_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kSubQKHeaddim>{}),
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQDramSingleRepMTileDistribution<Problem>());
|
||||
|
||||
@@ -208,7 +207,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0Sub>{}, number<kSubQKHeaddim>{}),
|
||||
make_tuple(number<kN0Sub>{}, number<kQKHeaddim>{}),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
@@ -245,7 +244,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_write_window = make_tile_window(
|
||||
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
// when kQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds,
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
|
||||
@@ -260,25 +259,15 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
auto k_lds_window = make_tile_window(
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
using k_lds_write_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kSubQKHeaddim>{}));
|
||||
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
using k_lds_read_window_type = decltype(get_slice_tile(
|
||||
using k_lds_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
|
||||
|
||||
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
|
||||
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
|
||||
statically_indexed_array<k_lds_window_type, NumKVLdsBuffers> k_lds_windows;
|
||||
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
k_lds_write_windows[i_buf] =
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{});
|
||||
k_lds_read_windows[i_buf] =
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
|
||||
k_lds_windows[i_buf] = get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
|
||||
});
|
||||
|
||||
// V tile in LDS
|
||||
@@ -401,7 +390,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
{
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
k_tiles[number<i_n0 % NumPrefetchK>{}],
|
||||
partition_index);
|
||||
|
||||
@@ -428,7 +417,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
block_sync_lds();
|
||||
|
||||
// execute current unroll of gemm_0
|
||||
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
|
||||
@@ -673,7 +662,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
|
||||
@@ -55,7 +55,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQK = Traits::kPadHeadDimQK;
|
||||
static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? true : Traits::kPadHeadDimV;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
@@ -126,9 +126,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
typename OAccElementFunction,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
@@ -152,8 +152,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kSubQKHeaddim ==
|
||||
KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
@@ -197,7 +196,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
auto q_dram_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kM0>{}, number<kSubQKHeaddim>{}),
|
||||
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQDramTileDistribution<Problem>());
|
||||
|
||||
@@ -207,7 +206,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0Sub>{}, number<kSubQKHeaddim>{}),
|
||||
make_tuple(number<kN0Sub>{}, number<kQKHeaddim>{}),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
@@ -241,7 +240,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_write_window = make_tile_window(
|
||||
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds,
|
||||
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
|
||||
@@ -255,25 +253,15 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
auto k_lds_window = make_tile_window(
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
using k_lds_write_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kSubQKHeaddim>{}));
|
||||
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
using k_lds_read_window_type = decltype(get_slice_tile(
|
||||
using k_lds_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
|
||||
|
||||
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
|
||||
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
|
||||
statically_indexed_array<k_lds_window_type, NumKVLdsBuffers> k_lds_windows;
|
||||
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
k_lds_write_windows[i_buf] =
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{});
|
||||
k_lds_read_windows[i_buf] =
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
|
||||
k_lds_windows[i_buf] = get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
|
||||
});
|
||||
|
||||
// V tile in LDS
|
||||
@@ -372,7 +360,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
{
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
k_tiles[number<i_n0 % NumPrefetchK>{}],
|
||||
partition_index);
|
||||
|
||||
@@ -399,7 +387,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
block_sync_lds();
|
||||
|
||||
// execute current unroll of gemm_0
|
||||
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
|
||||
@@ -638,8 +626,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
HstuMask mask,
|
||||
|
||||
Reference in New Issue
Block a user