mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Use separate pipelines for the softmax and no-softmax cases
This commit is contained in:
@@ -15,8 +15,10 @@
|
||||
#include "hstu_attention_hdim_switch.hpp"
|
||||
#include "hstu_attention_pipeline_problem.hpp"
|
||||
#include "hstu_attention_traits.hpp"
|
||||
#include "hstu_attention_fwd_pipeline.hpp"
|
||||
#include "hstu_attention_fwd_trload_pipeline.hpp"
|
||||
#include "hstu_attention_with_softmax_fwd_pipeline.hpp"
|
||||
#include "hstu_attention_no_softmax_fwd_pipeline.hpp"
|
||||
#include "hstu_attention_with_softmax_fwd_trload_pipeline.hpp"
|
||||
#include "hstu_attention_no_softmax_fwd_trload_pipeline.hpp"
|
||||
#include "hstu_attention_fwd_kernel.hpp"
|
||||
#include "hstu_attention_epilogue.hpp"
|
||||
|
||||
@@ -66,38 +68,52 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
|
||||
// buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
|
||||
constexpr bool kPadSeqLenQ = false;
|
||||
|
||||
BOOL_SWITCH_3(pad_seqlen_k,
|
||||
kPadSeqLenK,
|
||||
pad_headdim_qk,
|
||||
kPadHeadDimQK,
|
||||
pad_headdim_v,
|
||||
kPadHeadDimV,
|
||||
[&] {
|
||||
using HstuTraits = ck_tile::HstuAttentionFwdTraits<kPadSeqLenQ,
|
||||
kPadSeqLenK,
|
||||
kPadHeadDimQK,
|
||||
kPadHeadDimV,
|
||||
occupancy>;
|
||||
BOOL_SWITCH_3(
|
||||
pad_seqlen_k,
|
||||
kPadSeqLenK,
|
||||
pad_headdim_qk,
|
||||
kPadHeadDimQK,
|
||||
pad_headdim_v,
|
||||
kPadHeadDimV,
|
||||
[&] {
|
||||
using HstuTraits = ck_tile::HstuAttentionFwdTraits<kPadSeqLenQ,
|
||||
kPadSeqLenK,
|
||||
kPadHeadDimQK,
|
||||
kPadHeadDimV,
|
||||
occupancy>;
|
||||
|
||||
using HstuPipelineProblem = HstuPipelineProblemTemp<HstuTraits>;
|
||||
using HstuPipelineProblem = HstuPipelineProblemTemp<HstuTraits>;
|
||||
|
||||
using HstuEpilogue =
|
||||
ck_tile::NRepetitions2DEpilogue<ck_tile::Default2DEpilogueProblem<
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::ODataType,
|
||||
kPadSeqLenQ,
|
||||
kPadHeadDimV>>;
|
||||
using HstuEpilogue =
|
||||
ck_tile::NRepetitions2DEpilogue<ck_tile::Default2DEpilogueProblem<
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::ODataType,
|
||||
kPadSeqLenQ,
|
||||
kPadHeadDimV>>;
|
||||
|
||||
using HstuPipeline = std::conditional_t<
|
||||
kUseTrLoad,
|
||||
ck_tile::HstuAttentionFwdPipelineQRKSVSTrLoad<HstuPipelineProblem>,
|
||||
ck_tile::HstuAttentionFwdPipelineQRKSVS<HstuPipelineProblem>>;
|
||||
if constexpr(!kUseTrLoad)
|
||||
{
|
||||
using HstuPipeline = std::conditional_t<
|
||||
kUseSoftmax,
|
||||
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem>,
|
||||
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem>>;
|
||||
|
||||
using HstuKernel =
|
||||
ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
|
||||
using HstuKernel = ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
RunWithKernel<HstuKernel>(param, stream);
|
||||
});
|
||||
RunWithKernel<HstuKernel>(param, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
using HstuPipeline = std::conditional_t<
|
||||
kUseSoftmax,
|
||||
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad<HstuPipelineProblem>,
|
||||
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad<HstuPipelineProblem>>;
|
||||
|
||||
using HstuKernel = ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
RunWithKernel<HstuKernel>(param, stream);
|
||||
};
|
||||
});
|
||||
};
|
||||
|
||||
template <typename HstuKernel>
|
||||
|
||||
@@ -15,8 +15,10 @@
|
||||
#include "hstu_attention_hdim_switch.hpp"
|
||||
#include "hstu_attention_pipeline_problem.hpp"
|
||||
#include "hstu_attention_traits.hpp"
|
||||
#include "hstu_attention_fwd_pipeline.hpp"
|
||||
#include "hstu_attention_fwd_trload_pipeline.hpp"
|
||||
#include "hstu_attention_with_softmax_fwd_pipeline.hpp"
|
||||
#include "hstu_attention_no_softmax_fwd_pipeline.hpp"
|
||||
#include "hstu_attention_with_softmax_fwd_trload_pipeline.hpp"
|
||||
#include "hstu_attention_no_softmax_fwd_trload_pipeline.hpp"
|
||||
#include "hstu_attention_fwd_kernel.hpp"
|
||||
#include "hstu_attention_epilogue.hpp"
|
||||
|
||||
@@ -82,13 +84,28 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
|
||||
kPadSeqLenQ,
|
||||
kPadHeadDimV>>;
|
||||
|
||||
using HstuPipeline = std::conditional_t<
|
||||
kUseTrLoad,
|
||||
ck_tile::HstuAttentionFwdPipelineQRKSVSTrLoad<HstuPipelineProblem>,
|
||||
ck_tile::HstuAttentionFwdPipelineQRKSVS<HstuPipelineProblem>>;
|
||||
using HstuKernel = ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
|
||||
if constexpr(!kUseTrLoad)
|
||||
{
|
||||
using HstuPipeline = std::conditional_t<
|
||||
kUseSoftmax,
|
||||
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem>,
|
||||
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem>>;
|
||||
|
||||
RunWithKernel<HstuKernel>(param, stream);
|
||||
using HstuKernel = ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
RunWithKernel<HstuKernel>(param, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
using HstuPipeline = std::conditional_t<
|
||||
kUseSoftmax,
|
||||
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad<HstuPipelineProblem>,
|
||||
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad<HstuPipelineProblem>>;
|
||||
|
||||
using HstuKernel = ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
RunWithKernel<HstuKernel>(param, stream);
|
||||
};
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,549 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
|
||||
#include "hstu_attention_fwd_pipeline_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = HstuAttentionFwdPipelineQRKSVSDefaultPolicy>
|
||||
struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QKVDataType = remove_cvref_t<typename Problem::InOutDataType>;
|
||||
using GemmAccDataType = remove_cvref_t<typename Problem::GemmAccDataType>;
|
||||
using CompDataType = remove_cvref_t<typename Problem::CompDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::InOutDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::InOutDataType>;
|
||||
|
||||
using HstuAttentionTileSetting = remove_cvref_t<typename Problem::HstuAttentionTileSetting>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = HstuAttentionTileSetting::kM0;
|
||||
static constexpr index_t kN0 = HstuAttentionTileSetting::kN0;
|
||||
static constexpr index_t kN1 = HstuAttentionTileSetting::kN1;
|
||||
static constexpr index_t kK1 = HstuAttentionTileSetting::kK1;
|
||||
static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
static_assert(Problem::kUseSoftmax == false, "This pipeline only works with not-using softmax");
|
||||
|
||||
static constexpr bool kIsJagged = Problem::kIsJagged;
|
||||
static constexpr auto kHasBias = Problem::kHasBias;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr bool kHasCausal = Problem::kHasCausal;
|
||||
|
||||
static constexpr bool kPadSeqLenQ = Problem::Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQK = Problem::Traits::kPadHeadDimQK;
|
||||
// Head-dim-V padding flag: forced on whenever kQKHeaddim is smaller than
// kSubQKHeaddim, otherwise taken from the traits.
static constexpr bool kPadHeadDimV =
    (kQKHeaddim < kSubQKHeaddim) || Problem::Traits::kPadHeadDimV;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should be able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQK ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQK ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
Problem::Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
|
||||
static constexpr index_t kGemmSingleRepM = Policy::template GetQKBlockGemmSingleRepM<Problem>();
|
||||
static constexpr index_t kGemmNumRepM = kM0 / kGemmSingleRepM;
|
||||
|
||||
// used by NRepetitions2DEpilogue
|
||||
static constexpr index_t kGemm1SingleRepN =
|
||||
Policy::template GetKVBlockGemmSingleRepN<Problem>();
|
||||
|
||||
// Blocks-per-CU setting for this pipeline.
// NOTE(review): presumably an occupancy hint consumed by the kernel launch -- confirm.
//
// A traits value other than -1 is used verbatim. Otherwise the value is derived
// from the QK head dimension: 32, 64, 96 and 128 all map to 2 (with or without
// bias), while 256 and every other size map to 1.
static constexpr index_t kBlockPerCu = []() {
    if constexpr(Problem::Traits::kBlockPerCu != -1)
    {
        return Problem::Traits::kBlockPerCu;
    }
    else
    {
        constexpr bool kHeaddimUpTo128 =
            kQKHeaddim == 32 || kQKHeaddim == 64 || kQKHeaddim == 96 || kQKHeaddim == 128;

        return kHeaddimUpTo128 ? 2 : 1;
    }
}();
|
||||
|
||||
static constexpr const char* name = "qr_hstu";
|
||||
|
||||
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
|
||||
|
||||
// Returns exactly what Policy::GetSmemSize<Problem>() reports; the pipeline adds nothing of its own.
// NOTE(review): presumably the LDS (shared-memory) footprint in bytes -- confirm against the policy.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
typename VElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename HstuMask>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
|
||||
const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const VElementFunction& v_element_func,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
HstuMask& mask,
|
||||
float scale_s, // scaling value exerted on the immediate Q@K result
|
||||
float scale_p, // scaling value exerted on the SiLu result
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
ignore = q_element_func;
|
||||
ignore = k_element_func;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<QKVDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QKVDataType,
|
||||
remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QKVDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
|
||||
// SaccBlockTile size is [kM0, kK1]
|
||||
// PcompBlockTile size is [kM0, kN0]
|
||||
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kK1>());
|
||||
using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0>());
|
||||
using PcompBlockTileType = decltype(cast_tile<CompDataType>(CombineSaccBlockTileType{}));
|
||||
|
||||
SaccBlockTileType sacc_tile;
|
||||
PcompBlockTileType pcomp_tile;
|
||||
|
||||
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
|
||||
OaccBlockTileType o_acc;
|
||||
|
||||
auto q_dram_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQDramSingleRepMTileDistribution<Problem>());
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kK1>{}, number<kQKHeaddim>{}),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
using q_dram_tile_type = decltype(load_tile(q_dram_window));
|
||||
statically_indexed_array<q_dram_tile_type, kGemmNumRepM> q_dram_tiles;
|
||||
|
||||
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
|
||||
q_dram_tiles[i_rep] = load_tile(q_dram_window);
|
||||
move_tile_window(q_dram_window, {kGemmSingleRepM, 0});
|
||||
});
|
||||
|
||||
using k_tile_type = decltype(load_tile(k_dram_window));
|
||||
|
||||
statically_indexed_array<k_tile_type, k1_loops> k_tiles;
|
||||
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Q tile in LDS
|
||||
QKVDataType* q_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_write_window = make_tile_window(
|
||||
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds,
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeQRegSingleRepMTileDistribution<Problem>());
|
||||
|
||||
// K tile in LDS
|
||||
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_write_window = make_tile_window(
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kK1>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
using k_lds_write_window_type = decltype(get_slice_tile(
|
||||
k_lds_write_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
|
||||
|
||||
using k_lds_read_window_type = decltype(get_slice_tile(
|
||||
k_lds_read_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
|
||||
|
||||
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
|
||||
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
|
||||
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
k_lds_write_windows[i_buf] =
|
||||
get_slice_tile(k_lds_write_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{});
|
||||
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_read_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
|
||||
});
|
||||
|
||||
// V tile in LDS
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<QKVDataType*>(smem_ptr),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>());
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
using v_lds_window_type =
|
||||
decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{}));
|
||||
|
||||
statically_indexed_array<v_lds_window_type, NumKVLdsBuffers> v_lds_windows;
|
||||
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
v_lds_windows[i_buf] = get_slice_tile(
|
||||
v_lds_window, sequence<i_buf * kN1, 0>{}, sequence<(i_buf + 1) * kN1, kK1>{});
|
||||
});
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_k_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_silu = [&](CompDataType& x) {
|
||||
const auto one = ck_tile::type_convert<CompDataType>(1.0f);
|
||||
|
||||
if constexpr(std::is_same_v<CompDataType, float>)
|
||||
{
|
||||
x = x * __builtin_amdgcn_rcpf(one + __expf(-x));
|
||||
}
|
||||
else
|
||||
{
|
||||
x = x / (one + exp(-x));
|
||||
}
|
||||
};
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kM0>{}, number<kK1>{}),
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<Problem>());
|
||||
|
||||
auto null_randval_window = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
const auto null_randval_dram = [&]() {
|
||||
const auto null_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<uint8_t*>(nullptr),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(null_dram_naive,
|
||||
make_tuple(number<1>{}, number<1>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
|
||||
}();
|
||||
|
||||
using q_reg_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeQRegSingleRepMTileDistribution<Problem>()));
|
||||
statically_indexed_array<q_reg_tile_type, kGemmNumRepM> q_reg_tiles;
|
||||
|
||||
using q_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeQRegTileDistribution<Problem>()));
|
||||
|
||||
q_tile_type q_tile;
|
||||
|
||||
{
|
||||
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
|
||||
store_tile(q_lds_write_window, q_dram_tiles[i_rep]);
|
||||
|
||||
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice written
|
||||
// by each wavefront is read by itself
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
|
||||
q_reg_tiles[i_rep] = load_tile(q_lds_read_window);
|
||||
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
|
||||
// the compiler does not generate actual instructions for the following code
|
||||
set_slice_tile(q_tile,
|
||||
q_reg_tiles[i_rep],
|
||||
sequence<i_rep * kGemmSingleRepM, 0>{},
|
||||
sequence<(i_rep + 1) * kGemmSingleRepM, kQKHeaddim>{});
|
||||
|
||||
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice read
|
||||
// by each wavefront is over-written by itself
|
||||
});
|
||||
|
||||
clear_tile(o_acc);
|
||||
};
|
||||
|
||||
q_tile = tile_elementwise_in(q_element_func, q_tile);
|
||||
|
||||
auto seqlen_k_curr = seqlen_k_start;
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
using v_tile_type = decltype(load_tile(v_dram_window));
|
||||
|
||||
statically_indexed_array<v_tile_type, k1_loops> v_tiles;
|
||||
|
||||
do
|
||||
{
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(k_lds_write_windows[i_k1],
|
||||
tile_elementwise_in(k_element_func, k_tiles[i_k1]));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// load v_tiles used in current iteration
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// execute current unroll of gemm_0
|
||||
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_k1 * kK1>{},
|
||||
sequence<kM0, (i_k1 + 1) * kK1>{});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, siLU
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
const auto bias_tile = load_tile(bias_dram_window);
|
||||
|
||||
tile_elementwise_inout(
|
||||
[&scale_s, &bias_element_func](auto& x, const auto& y) {
|
||||
x = x * scale_s + type_convert<CompDataType>(bias_element_func(y));
|
||||
},
|
||||
pcomp_tile,
|
||||
bias_tile);
|
||||
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile);
|
||||
}
|
||||
|
||||
if(!mask.IsFullTileInsideMask(
|
||||
q_origin.at(number<0>{}), seqlen_k_curr, number<kN0>{}, number<kM0>{}))
|
||||
{
|
||||
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if(!mask.IsTokenPairInsideMask(row, col))
|
||||
{
|
||||
pcomp_tile(i_j_idx) = type_convert<CompDataType>(0.0f);
|
||||
};
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
tile_elementwise_inout(f_silu, pcomp_tile);
|
||||
|
||||
tile_elementwise_inout([&](auto& x) { x = x * type_convert<CompDataType>(scale_p); },
|
||||
pcomp_tile);
|
||||
|
||||
seqlen_k_curr += kN0;
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto randval_lds_ptr =
|
||||
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
|
||||
|
||||
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
|
||||
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
|
||||
}
|
||||
|
||||
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
|
||||
|
||||
using v_shuffled_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeShuffledVRegTileDistribution<Problem>()));
|
||||
|
||||
statically_indexed_array<v_shuffled_tile_type, k1_loops> v_shuffled_tiles;
|
||||
|
||||
static_for<0, k1_loops, 1>{}(
|
||||
[&](auto i_k1) { shuffle_tile(v_shuffled_tiles[i_k1], v_tiles[i_k1]); });
|
||||
|
||||
// check whether the first V-LdsBuffer overlaps with the last K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
// STAGE 3, Gemm_1 ( O = P@V )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffled_tiles[i_k1]));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// load k_tiles used by next iteration
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
gemm_1(
|
||||
o_acc,
|
||||
get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]);
|
||||
});
|
||||
|
||||
// check whether the last V-LdsBuffer overlaps with the first K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((k1_loops - 1 + 2) % NumKVLdsBuffers == 0)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
} while(seqlen_k_curr < seqlen_k_end);
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename HstuMask>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
HstuMask mask,
|
||||
float scale_s, // scaling value exerted on the immediate Q@K result
|
||||
float scale_p, // scaling value exerted on the SiLU result
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
identity{},
|
||||
v_dram_block_window_tmp,
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
mask,
|
||||
scale_s,
|
||||
scale_p,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,540 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
|
||||
#include "hstu_attention_fwd_pipeline_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = HstuAttentionFwdPipelineQRKSVSDefaultPolicy>
|
||||
struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QKVDataType = remove_cvref_t<typename Problem::InOutDataType>;
|
||||
using GemmAccDataType = remove_cvref_t<typename Problem::GemmAccDataType>;
|
||||
using CompDataType = remove_cvref_t<typename Problem::CompDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::InOutDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::InOutDataType>;
|
||||
|
||||
using HstuAttentionTileSetting = remove_cvref_t<typename Problem::HstuAttentionTileSetting>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = HstuAttentionTileSetting::kM0;
|
||||
static constexpr index_t kN0 = HstuAttentionTileSetting::kN0;
|
||||
static constexpr index_t kN1 = HstuAttentionTileSetting::kN1;
|
||||
static constexpr index_t kK1 = HstuAttentionTileSetting::kK1;
|
||||
static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
static_assert(Problem::kUseSoftmax == false, "This pipeline only works with not-using softmax");
|
||||
|
||||
static constexpr bool kIsJagged = Problem::kIsJagged;
|
||||
static constexpr auto kHasBias = Problem::kHasBias;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr bool kHasCausal = Problem::kHasCausal;
|
||||
|
||||
static_assert(Problem::kUseTrLoad == true, "Check failed!");
|
||||
|
||||
static constexpr bool kUseTrLoad = true;
|
||||
|
||||
static constexpr bool kPadSeqLenQ = Problem::Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQK = Problem::Traits::kPadHeadDimQK;
|
||||
// Head-dim-V padding flag: forced on whenever kQKHeaddim is smaller than
// kSubQKHeaddim, otherwise taken from the traits.
static constexpr bool kPadHeadDimV =
    (kQKHeaddim < kSubQKHeaddim) || Problem::Traits::kPadHeadDimV;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should be able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQK ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQK ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
Problem::Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
|
||||
static constexpr index_t kGemmSingleRepM = Policy::template GetQKBlockGemmSingleRepM<Problem>();
|
||||
static constexpr index_t kGemmNumRepM = kM0 / kGemmSingleRepM;
|
||||
|
||||
// used by NRepetitions2DEpilogue
|
||||
static constexpr index_t kGemm1SingleRepN =
|
||||
Policy::template GetKVBlockGemmSingleRepN<Problem>();
|
||||
|
||||
// Blocks-per-CU setting for this pipeline.
// NOTE(review): presumably an occupancy hint consumed by the kernel launch -- confirm.
//
// A traits value other than -1 is used verbatim. Otherwise the value is derived
// from the QK head dimension: 32, 64, 96 and 128 all map to 2 (with or without
// bias), while 256 and every other size map to 1.
static constexpr index_t kBlockPerCu = []() {
    if constexpr(Problem::Traits::kBlockPerCu != -1)
    {
        return Problem::Traits::kBlockPerCu;
    }
    else
    {
        constexpr bool kHeaddimUpTo128 =
            kQKHeaddim == 32 || kQKHeaddim == 64 || kQKHeaddim == 96 || kQKHeaddim == 128;

        return kHeaddimUpTo128 ? 2 : 1;
    }
}();
|
||||
|
||||
// Short identifier string for this pipeline variant.
static constexpr const char* name = "qr_hstu";
|
||||
|
||||
// Dropout implementation chosen at compile time: BlockDropout when kHasDropout
// is set, NullBlockDropout otherwise.
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
|
||||
|
||||
// Shared-memory size needed by this pipeline, as reported by the policy for
// the current Problem.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
    const ck_tile::index_t smem_size = Policy::template GetSmemSize<Problem>();
    return smem_size;
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
typename VElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename HstuMask>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
|
||||
const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const VElementFunction& v_element_func,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
HstuMask& mask,
|
||||
float scale_s, // scaling value exerted on the immediate Q@K result
|
||||
float scale_p, // scaling value exerted on the SiLu result
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
ignore = q_element_func;
|
||||
ignore = k_element_func;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<QKVDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QKVDataType,
|
||||
remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QKVDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
|
||||
// SaccBlockTile size is [kM0, kK1]
|
||||
// PcompBlockTile size is [kM0, kN0]
|
||||
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kK1>());
|
||||
using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0>());
|
||||
using PcompBlockTileType = decltype(cast_tile<CompDataType>(CombineSaccBlockTileType{}));
|
||||
|
||||
SaccBlockTileType sacc_tile;
|
||||
PcompBlockTileType pcomp_tile;
|
||||
|
||||
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
|
||||
OaccBlockTileType o_acc;
|
||||
|
||||
auto q_dram_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQDramSingleRepMTileDistribution<Problem>());
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kK1>{}, number<kQKHeaddim>{}),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
using q_dram_tile_type = decltype(load_tile(q_dram_window));
|
||||
statically_indexed_array<q_dram_tile_type, kGemmNumRepM> q_dram_tiles;
|
||||
|
||||
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
|
||||
q_dram_tiles[i_rep] = load_tile(q_dram_window);
|
||||
move_tile_window(q_dram_window, {kGemmSingleRepM, 0});
|
||||
});
|
||||
|
||||
using k_tile_type = decltype(load_tile(k_dram_window));
|
||||
|
||||
statically_indexed_array<k_tile_type, k1_loops> k_tiles;
|
||||
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Q tile in LDS
|
||||
QKVDataType* q_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_write_window = make_tile_window(
|
||||
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds,
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeQRegSingleRepMTileDistribution<Problem>());
|
||||
|
||||
// K tile in LDS
|
||||
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_write_window = make_tile_window(
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kK1>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
using k_lds_write_window_type = decltype(get_slice_tile(
|
||||
k_lds_write_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
|
||||
|
||||
using k_lds_read_window_type = decltype(get_slice_tile(
|
||||
k_lds_read_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
|
||||
|
||||
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
|
||||
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
|
||||
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
k_lds_write_windows[i_buf] =
|
||||
get_slice_tile(k_lds_write_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{});
|
||||
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_read_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
|
||||
});
|
||||
|
||||
// V tile in LDS
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<QKVDataType*>(smem_ptr),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>());
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
using v_lds_window_type =
|
||||
decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kK1, kN1>{}));
|
||||
|
||||
statically_indexed_array<v_lds_window_type, NumKVLdsBuffers> v_lds_windows;
|
||||
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
v_lds_windows[i_buf] = get_slice_tile(
|
||||
v_lds_window, sequence<i_buf * kK1, 0>{}, sequence<(i_buf + 1) * kK1, kN1>{});
|
||||
});
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_silu = [&](CompDataType& x) {
|
||||
const auto one = ck_tile::type_convert<CompDataType>(1.0f);
|
||||
|
||||
if constexpr(std::is_same_v<CompDataType, float>)
|
||||
{
|
||||
x = x * __builtin_amdgcn_rcpf(one + __expf(-x));
|
||||
}
|
||||
else
|
||||
{
|
||||
x = x / (one + exp(-x));
|
||||
}
|
||||
};
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kM0>{}, number<kK1>{}),
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<Problem>());
|
||||
|
||||
auto null_randval_window = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
const auto null_randval_dram = [&]() {
|
||||
const auto null_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<uint8_t*>(nullptr),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(null_dram_naive,
|
||||
make_tuple(number<1>{}, number<1>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
|
||||
}();
|
||||
|
||||
using q_reg_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeQRegSingleRepMTileDistribution<Problem>()));
|
||||
statically_indexed_array<q_reg_tile_type, kGemmNumRepM> q_reg_tiles;
|
||||
|
||||
using q_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeQRegTileDistribution<Problem>()));
|
||||
|
||||
q_tile_type q_tile;
|
||||
|
||||
{
|
||||
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
|
||||
store_tile(q_lds_write_window, q_dram_tiles[i_rep]);
|
||||
|
||||
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice written
|
||||
// by each wavefront is read by itself
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
|
||||
q_reg_tiles[i_rep] = load_tile(q_lds_read_window);
|
||||
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
|
||||
// the following codes will not generate actual instructions by the compiler
|
||||
set_slice_tile(q_tile,
|
||||
q_reg_tiles[i_rep],
|
||||
sequence<i_rep * kGemmSingleRepM, 0>{},
|
||||
sequence<(i_rep + 1) * kGemmSingleRepM, kQKHeaddim>{});
|
||||
|
||||
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice read
|
||||
// by each wavefront is over-written by itself
|
||||
});
|
||||
|
||||
clear_tile(o_acc);
|
||||
};
|
||||
|
||||
q_tile = tile_elementwise_in(q_element_func, q_tile);
|
||||
|
||||
auto seqlen_k_curr = seqlen_k_start;
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
using v_tile_type = decltype(load_tile(v_dram_window));
|
||||
|
||||
statically_indexed_array<v_tile_type, k1_loops> v_tiles;
|
||||
|
||||
do
|
||||
{
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(k_lds_write_windows[i_k1],
|
||||
tile_elementwise_in(k_element_func, k_tiles[i_k1]));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// load v_tiles used in current iteration
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// execute current unroll of gemm_0
|
||||
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_k1 * kK1>{},
|
||||
sequence<kM0, (i_k1 + 1) * kK1>{});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, siLU
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
const auto bias_tile = load_tile(bias_dram_window);
|
||||
|
||||
tile_elementwise_inout(
|
||||
[&scale_s, &bias_element_func](auto& x, const auto& y) {
|
||||
x = x * scale_s + type_convert<CompDataType>(bias_element_func(y));
|
||||
},
|
||||
pcomp_tile,
|
||||
bias_tile);
|
||||
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile);
|
||||
}
|
||||
|
||||
if(!mask.IsFullTileInsideMask(
|
||||
q_origin.at(number<0>{}), seqlen_k_curr, number<kN0>{}, number<kM0>{}))
|
||||
{
|
||||
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if(!mask.IsTokenPairInsideMask(row, col))
|
||||
{
|
||||
pcomp_tile(i_j_idx) = type_convert<CompDataType>(0.0f);
|
||||
};
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
tile_elementwise_inout(f_silu, pcomp_tile);
|
||||
|
||||
tile_elementwise_inout([&](auto& x) { x = x * type_convert<CompDataType>(scale_p); },
|
||||
pcomp_tile);
|
||||
|
||||
seqlen_k_curr += kN0;
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto randval_lds_ptr =
|
||||
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
|
||||
|
||||
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
|
||||
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
|
||||
}
|
||||
|
||||
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
|
||||
|
||||
// check whether first V-LdsBufer overlap with last K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
// STAGE 3, Gemm_1 ( O = P@V )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_tiles[number<i_k1>{}]));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// load k_tiles used by next iteration
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
gemm_1(
|
||||
o_acc,
|
||||
get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_windows[number<i_k1 + 2>{}]);
|
||||
});
|
||||
} while(seqlen_k_curr < seqlen_k_end);
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename HstuMask>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
HstuMask mask,
|
||||
float scale_s, // scaling value exerted on the immediate Q@K result
|
||||
float scale_p, // scaling value exerted on the SiLU result
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
identity{},
|
||||
v_dram_block_window_tmp,
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
mask,
|
||||
scale_s,
|
||||
scale_p,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -11,7 +11,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = HstuAttentionFwdPipelineQRKSVSDefaultPolicy>
|
||||
struct HstuAttentionFwdPipelineQRKSVS
|
||||
struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
@@ -35,6 +35,8 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
static_assert(Problem::kUseSoftmax == true, "This pipeline only works with using softmax");
|
||||
|
||||
static constexpr bool kIsJagged = Problem::kIsJagged;
|
||||
static constexpr auto kHasBias = Problem::kHasBias;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
@@ -143,6 +145,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
{
|
||||
ignore = q_element_func;
|
||||
ignore = k_element_func;
|
||||
ignore = scale_p;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<QKVDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -160,8 +163,6 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr bool kUseSoftmax = Problem::kUseSoftmax;
|
||||
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
|
||||
@@ -293,20 +294,6 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
{0, seqlen_k_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_silu = [&](CompDataType& x) {
|
||||
const auto one = ck_tile::type_convert<CompDataType>(1.0f);
|
||||
|
||||
if constexpr(std::is_same_v<CompDataType, float>)
|
||||
{
|
||||
x = x * __builtin_amdgcn_rcpf(one + __expf(-x));
|
||||
}
|
||||
else
|
||||
{
|
||||
x = x / (one + exp(-x));
|
||||
}
|
||||
};
|
||||
|
||||
const auto f_exp = [&](CompDataType x) {
|
||||
if constexpr(std::is_same_v<CompDataType, float>)
|
||||
{
|
||||
@@ -381,11 +368,8 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
clear_tile(o_acc);
|
||||
|
||||
if constexpr(kUseSoftmax)
|
||||
{
|
||||
set_tile(m, -numeric<CompDataType>::infinity());
|
||||
clear_tile(l);
|
||||
};
|
||||
set_tile(m, -numeric<CompDataType>::infinity());
|
||||
clear_tile(l);
|
||||
};
|
||||
|
||||
q_tile = tile_elementwise_in(q_element_func, q_tile);
|
||||
@@ -454,129 +438,98 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile);
|
||||
}
|
||||
|
||||
if constexpr(!kUseSoftmax)
|
||||
if(!mask.IsFullTileInsideMask(
|
||||
q_origin.at(number<0>{}), seqlen_k_curr, number<kN0>{}, number<kM0>{}))
|
||||
{
|
||||
if(!mask.IsFullTileInsideMask(
|
||||
q_origin.at(number<0>{}), seqlen_k_curr, number<kN0>{}, number<kM0>{}))
|
||||
{
|
||||
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if(!mask.IsTokenPairInsideMask(row, col))
|
||||
{
|
||||
pcomp_tile(i_j_idx) = type_convert<CompDataType>(0.0f);
|
||||
};
|
||||
});
|
||||
if(!mask.IsTokenPairInsideMask(row, col) || col >= seqlen_k_end)
|
||||
{
|
||||
pcomp_tile(i_j_idx) = -numeric<CompDataType>::infinity();
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
tile_elementwise_inout(f_silu, pcomp_tile);
|
||||
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x) { x = x * type_convert<CompDataType>(scale_p); }, pcomp_tile);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
if(!mask.IsFullTileInsideMask(
|
||||
q_origin.at(number<0>{}), seqlen_k_curr, number<kN0>{}, number<kM0>{}))
|
||||
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if(col >= seqlen_k_end)
|
||||
{
|
||||
pcomp_tile(i_j_idx) = -numeric<CompDataType>::infinity();
|
||||
};
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
auto m_local = block_tile_reduce<CompDataType>(
|
||||
pcomp_tile, sequence<1>{}, f_max, -numeric<CompDataType>::infinity());
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m;
|
||||
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
|
||||
|
||||
constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
{
|
||||
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if(!mask.IsTokenPairInsideMask(row, col) || col >= seqlen_k_end)
|
||||
{
|
||||
pcomp_tile(i_j_idx) = -numeric<CompDataType>::infinity();
|
||||
};
|
||||
});
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
pcomp_tile(i_j_idx) = type_convert<CompDataType>(0.0f);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if(col >= seqlen_k_end)
|
||||
{
|
||||
pcomp_tile(i_j_idx) = -numeric<CompDataType>::infinity();
|
||||
};
|
||||
});
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]);
|
||||
});
|
||||
};
|
||||
}
|
||||
});
|
||||
|
||||
auto m_local = block_tile_reduce<CompDataType>(
|
||||
pcomp_tile, sequence<1>{}, f_max, -numeric<CompDataType>::infinity());
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
auto rowsum_p =
|
||||
block_tile_reduce<CompDataType>(pcomp_tile, sequence<1>{}, f_sum, CompDataType{0});
|
||||
|
||||
const auto m_old = m;
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
|
||||
// adjust o_acc[] according to the update between m and m_old
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
{
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
pcomp_tile(i_j_idx) = type_convert<CompDataType>(0.0f);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
auto rowsum_p = block_tile_reduce<CompDataType>(
|
||||
pcomp_tile, sequence<1>{}, f_sum, CompDataType{0});
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
|
||||
// adjust o_acc[] according to the update between m and m_old
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
{
|
||||
l(i_idx) = rowsum_p[i_idx];
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto tmp = f_exp(m_old[i_idx] - m[i_idx]);
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
}
|
||||
});
|
||||
};
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
{
|
||||
l(i_idx) = rowsum_p[i_idx];
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto tmp = f_exp(m_old[i_idx] - m[i_idx]);
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
seqlen_k_curr += kN0;
|
||||
|
||||
@@ -635,22 +588,19 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
};
|
||||
} while(seqlen_k_curr < seqlen_k_end);
|
||||
|
||||
if constexpr(kUseSoftmax)
|
||||
{
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
o_acc(i_j_idx) = 0.0f;
|
||||
else
|
||||
o_acc(i_j_idx) *= 1.0f / l[i_idx];
|
||||
});
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
o_acc(i_j_idx) = 0.0f;
|
||||
else
|
||||
o_acc(i_j_idx) *= 1.0f / l[i_idx];
|
||||
});
|
||||
};
|
||||
});
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = HstuAttentionFwdPipelineQRKSVSDefaultPolicy>
|
||||
struct HstuAttentionFwdPipelineQRKSVSTrLoad
|
||||
struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
@@ -35,6 +35,8 @@ struct HstuAttentionFwdPipelineQRKSVSTrLoad
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
static_assert(Problem::kUseSoftmax == true, "This pipeline only works with using softmax");
|
||||
|
||||
static constexpr bool kIsJagged = Problem::kIsJagged;
|
||||
static constexpr auto kHasBias = Problem::kHasBias;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
@@ -143,6 +145,7 @@ struct HstuAttentionFwdPipelineQRKSVSTrLoad
|
||||
{
|
||||
ignore = q_element_func;
|
||||
ignore = k_element_func;
|
||||
ignore = scale_p;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<QKVDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -160,8 +163,6 @@ struct HstuAttentionFwdPipelineQRKSVSTrLoad
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr bool kUseSoftmax = Problem::kUseSoftmax;
|
||||
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
|
||||
@@ -293,20 +294,6 @@ struct HstuAttentionFwdPipelineQRKSVSTrLoad
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_silu = [&](CompDataType& x) {
|
||||
const auto one = ck_tile::type_convert<CompDataType>(1.0f);
|
||||
|
||||
if constexpr(std::is_same_v<CompDataType, float>)
|
||||
{
|
||||
x = x * __builtin_amdgcn_rcpf(one + __expf(-x));
|
||||
}
|
||||
else
|
||||
{
|
||||
x = x / (one + exp(-x));
|
||||
}
|
||||
};
|
||||
|
||||
const auto f_exp = [&](CompDataType x) {
|
||||
if constexpr(std::is_same_v<CompDataType, float>)
|
||||
{
|
||||
@@ -381,11 +368,8 @@ struct HstuAttentionFwdPipelineQRKSVSTrLoad
|
||||
|
||||
clear_tile(o_acc);
|
||||
|
||||
if constexpr(kUseSoftmax)
|
||||
{
|
||||
set_tile(m, -numeric<CompDataType>::infinity());
|
||||
clear_tile(l);
|
||||
};
|
||||
set_tile(m, -numeric<CompDataType>::infinity());
|
||||
clear_tile(l);
|
||||
};
|
||||
|
||||
q_tile = tile_elementwise_in(q_element_func, q_tile);
|
||||
@@ -454,129 +438,98 @@ struct HstuAttentionFwdPipelineQRKSVSTrLoad
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile);
|
||||
}
|
||||
|
||||
if constexpr(!kUseSoftmax)
|
||||
if(!mask.IsFullTileInsideMask(
|
||||
q_origin.at(number<0>{}), seqlen_k_curr, number<kN0>{}, number<kM0>{}))
|
||||
{
|
||||
if(!mask.IsFullTileInsideMask(
|
||||
q_origin.at(number<0>{}), seqlen_k_curr, number<kN0>{}, number<kM0>{}))
|
||||
{
|
||||
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if(!mask.IsTokenPairInsideMask(row, col))
|
||||
{
|
||||
pcomp_tile(i_j_idx) = type_convert<CompDataType>(0.0f);
|
||||
};
|
||||
});
|
||||
if(!mask.IsTokenPairInsideMask(row, col) || col >= seqlen_k_end)
|
||||
{
|
||||
pcomp_tile(i_j_idx) = -numeric<CompDataType>::infinity();
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
tile_elementwise_inout(f_silu, pcomp_tile);
|
||||
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x) { x = x * type_convert<CompDataType>(scale_p); }, pcomp_tile);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
if(!mask.IsFullTileInsideMask(
|
||||
q_origin.at(number<0>{}), seqlen_k_curr, number<kN0>{}, number<kM0>{}))
|
||||
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if(col >= seqlen_k_end)
|
||||
{
|
||||
pcomp_tile(i_j_idx) = -numeric<CompDataType>::infinity();
|
||||
};
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
auto m_local = block_tile_reduce<CompDataType>(
|
||||
pcomp_tile, sequence<1>{}, f_max, -numeric<CompDataType>::infinity());
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m;
|
||||
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
|
||||
|
||||
constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
{
|
||||
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if(!mask.IsTokenPairInsideMask(row, col) || col >= seqlen_k_end)
|
||||
{
|
||||
pcomp_tile(i_j_idx) = -numeric<CompDataType>::infinity();
|
||||
};
|
||||
});
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
pcomp_tile(i_j_idx) = type_convert<CompDataType>(0.0f);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if(col >= seqlen_k_end)
|
||||
{
|
||||
pcomp_tile(i_j_idx) = -numeric<CompDataType>::infinity();
|
||||
};
|
||||
});
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]);
|
||||
});
|
||||
};
|
||||
}
|
||||
});
|
||||
|
||||
auto m_local = block_tile_reduce<CompDataType>(
|
||||
pcomp_tile, sequence<1>{}, f_max, -numeric<CompDataType>::infinity());
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
auto rowsum_p =
|
||||
block_tile_reduce<CompDataType>(pcomp_tile, sequence<1>{}, f_sum, CompDataType{0});
|
||||
|
||||
const auto m_old = m;
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
|
||||
// adjust o_acc[] according to the update between m and m_old
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
{
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
pcomp_tile(i_j_idx) = type_convert<CompDataType>(0.0f);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
auto rowsum_p = block_tile_reduce<CompDataType>(
|
||||
pcomp_tile, sequence<1>{}, f_sum, CompDataType{0});
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
|
||||
// adjust o_acc[] according to the update between m and m_old
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
{
|
||||
l(i_idx) = rowsum_p[i_idx];
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto tmp = f_exp(m_old[i_idx] - m[i_idx]);
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
}
|
||||
});
|
||||
};
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
{
|
||||
l(i_idx) = rowsum_p[i_idx];
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto tmp = f_exp(m_old[i_idx] - m[i_idx]);
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
seqlen_k_curr += kN0;
|
||||
|
||||
@@ -622,22 +575,19 @@ struct HstuAttentionFwdPipelineQRKSVSTrLoad
|
||||
});
|
||||
} while(seqlen_k_curr < seqlen_k_end);
|
||||
|
||||
if constexpr(kUseSoftmax)
|
||||
{
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
o_acc(i_j_idx) = 0.0f;
|
||||
else
|
||||
o_acc(i_j_idx) *= 1.0f / l[i_idx];
|
||||
});
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
o_acc(i_j_idx) = 0.0f;
|
||||
else
|
||||
o_acc(i_j_idx) *= 1.0f / l[i_idx];
|
||||
});
|
||||
};
|
||||
});
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
Reference in New Issue
Block a user