mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -99,8 +100,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
|
||||
static constexpr const char* name = "qs";
|
||||
|
||||
// using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
|
||||
using DropoutType = int32_t; // unused
|
||||
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
@@ -267,7 +267,8 @@ struct BlockFmhaPipelineQSKSVS
|
||||
bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
// Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
@@ -620,10 +621,46 @@ struct BlockFmhaPipelineQSKSVS
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
// template <typename QDramBlockWindowTmp,
|
||||
// typename KDramBlockWindowTmp,
|
||||
// typename VDramBlockWindowTmp,
|
||||
// typename BiasDramBlockWindowTmp,
|
||||
// typename LSEDramBlockWindowTmp,
|
||||
// typename PositionEncoding>
|
||||
// CK_TILE_HOST_DEVICE auto
|
||||
// operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
// const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
// const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
// const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
// LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
// FmhaMask mask,
|
||||
// PositionEncoding position_encoding,
|
||||
// float scale_s,
|
||||
// void* smem_ptr) const
|
||||
// {
|
||||
// return operator()(q_dram_block_window_tmp,
|
||||
// identity{},
|
||||
// k_dram_block_window_tmp,
|
||||
// identity{},
|
||||
// v_dram_block_window_tmp,
|
||||
// identity{},
|
||||
// bias_dram_block_window_tmp,
|
||||
// identity{},
|
||||
// lse_dram_block_window_tmp,
|
||||
// identity{},
|
||||
// identity{},
|
||||
// identity{},
|
||||
// identity{},
|
||||
// mask,
|
||||
// position_encoding,
|
||||
// scale_s,
|
||||
// smem_ptr);
|
||||
// }
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
@@ -631,11 +668,13 @@ struct BlockFmhaPipelineQSKSVS
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -645,6 +684,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
randval_dram_block_window_tmp,
|
||||
lse_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
@@ -653,7 +693,8 @@ struct BlockFmhaPipelineQSKSVS
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user