mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Merge commit 'b6e899438631118ff962a6be12cabc9930366267' into develop
This commit is contained in:
@@ -763,15 +763,21 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
ck_tile::FillConstant<QGradDataType>{ck_tile::numeric<QGradDataType>::infinity()}(dq_host);
|
||||
ck_tile::FillConstant<KGradDataType>{ck_tile::numeric<KGradDataType>::infinity()}(dk_host);
|
||||
ck_tile::FillConstant<VGradDataType>{ck_tile::numeric<VGradDataType>::infinity()}(dv_host);
|
||||
ck_tile::FillConstant<AccDataType>{ck_tile::numeric<AccDataType>::infinity()}(dq_acc_host);
|
||||
dq_buf.ToDevice(dq_host.data());
|
||||
dk_buf.ToDevice(dk_host.data());
|
||||
dv_buf.ToDevice(dv_host.data());
|
||||
dq_acc_buf.ToDevice(dq_acc_host.data());
|
||||
|
||||
o_buf.ToDevice(o_host.data());
|
||||
lse_buf.ToDevice(lse_host.data());
|
||||
dq_buf.SetZero();
|
||||
dbias_buf.SetZero();
|
||||
dq_acc_buf.SetZero();
|
||||
|
||||
// non-deterministic kernels use atomic add to write dq
|
||||
// Some blocks may be skipped with a causal mask, so their dq is never set to zero
|
||||
// In these cases we therefore need to zero it out first
|
||||
if(!deterministic || mask.type == mask_enum::no_mask)
|
||||
dq_acc_buf.SetZero();
|
||||
|
||||
ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1};
|
||||
fmha_bwd(fmha_traits, fmha_args, stream_config_v);
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
tile_example_fmha_fwd -prec=fp16 -mode=0 -b=2 -h=1 -d=128 -d_v=24 -s=3 -s_k=99 -bias=n -p_drop=0.0 -lse=0 -iperm=0 -operm=0 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1
|
||||
tile_example_fmha_fwd -prec=fp16 -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1 -s_k=10 -s_kpad=32 -bias=n -p_drop=0.0 -lse=0 -iperm=0 -operm=0 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1
|
||||
tile_example_fmha_fwd -prec=fp16 -mode=0 -b=2 -h=1 -d=128 -d_v=24 -s=3 -s_k=99 -bias=n -p_drop=0.0 -lse=0 -iperm=1 -operm=1 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1
|
||||
tile_example_fmha_fwd -prec=fp16 -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1 -s_k=10 -s_kpad=32 -bias=n -p_drop=0.0 -lse=0 -iperm=1 -operm=1 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1
|
||||
|
||||
@@ -1767,6 +1767,9 @@ struct FmhaFwdKernel
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
|
||||
constexpr auto kDramTileK =
|
||||
FmhaPipeline::kKLoadOnce ? FmhaPipeline::kQKHeaddim : FmhaPipeline::kK0;
|
||||
|
||||
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
|
||||
constexpr index_t LDSLayerSize = 256 / sizeof(KDataType);
|
||||
constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
|
||||
@@ -1835,32 +1838,36 @@ struct FmhaFwdKernel
|
||||
{
|
||||
const auto k_dram_unmerged = transform_tensor_view(
|
||||
k_dram_pad,
|
||||
make_tuple(
|
||||
make_pass_through_transform(height),
|
||||
make_unmerge_transform(make_tuple(
|
||||
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{},
|
||||
number<FmhaPipeline::kAlignmentK>{}))),
|
||||
make_tuple(make_pass_through_transform(height),
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
|
||||
FmhaPipeline::kAlignmentK>{},
|
||||
number<kDramTileK / FmhaPipeline::kAlignmentK>{},
|
||||
number<FmhaPipeline::kAlignmentK>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
make_tuple(sequence<0>{}, sequence<1, 2, 3>{}));
|
||||
|
||||
const auto k_dram_permuted = transform_tensor_view(
|
||||
k_dram_unmerged,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(
|
||||
height,
|
||||
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{})),
|
||||
height, number<kDramTileK / FmhaPipeline::kAlignmentK>{})),
|
||||
make_pass_through_transform(
|
||||
number<FmhaPipeline::kQKHeaddim / kDramTileK /
|
||||
FmhaPipeline::kAlignmentK>{}),
|
||||
make_pass_through_transform(number<FmhaPipeline::kAlignmentK>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
return transform_tensor_view(
|
||||
k_dram_permuted,
|
||||
make_tuple(
|
||||
make_pass_through_transform(height),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{},
|
||||
number<FmhaPipeline::kAlignmentK>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(make_pass_through_transform(height),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
|
||||
FmhaPipeline::kAlignmentK>{},
|
||||
number<kDramTileK / FmhaPipeline::kAlignmentK>{},
|
||||
number<FmhaPipeline::kAlignmentK>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -37,6 +37,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
|
||||
static_assert(kQLoadOnce == Policy::QLoadOnce);
|
||||
static constexpr bool kKLoadOnce = BlockFmhaShape::kM0 >= 64;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
|
||||
@@ -16,6 +16,6 @@ const auto HDimValues =
|
||||
|
||||
const auto ModeValues = Values(mode_enum::batch, mode_enum::group);
|
||||
|
||||
constexpr std::string init_method = "uf";
|
||||
constexpr auto init_method = "uf";
|
||||
|
||||
#include "test_fmha_bwd.inc"
|
||||
|
||||
@@ -16,6 +16,6 @@ const auto HDimValues =
|
||||
|
||||
const auto ModeValues = Values(mode_enum::batch, mode_enum::group);
|
||||
|
||||
constexpr std::string init_method = "uf";
|
||||
constexpr auto init_method = "uf";
|
||||
|
||||
#include "test_fmha_bwd.inc"
|
||||
|
||||
Reference in New Issue
Block a user