[CK_TILE] FA bwd kernels optimization (#1397)

* tmp save

* fix batch deterministic bugs

* fix group deterministic bugs

* codegen update

* reorder files

* bias support

* hd256 bias support

* bwd smoke test update

* simplify convert dq

* fix hd256 dropout scratch

* do{}while() -> while(){}

* comments

* remove FmhaBwdTilePartitioner

* save clear_tile

* refactor dropout

* code cleanup

* code cleanup

* comments

* fix epilogue problem

* fix fwd dropout

* group convert_dq opt

* fix dq alignment

* Do not store storerandval in bwd for flash attention integration

* fix hd32 error and boost performance

* revert

* Remove duplicated WarpGemm definitions in the policy file

* dropout patch for mrepeat 16*16

* code sync up

* dq_acc stride

* dq_acc stride stuff

* codegen update

* fwd dropout revert

* fix hd128 scratches and boost performance

* receipt 3 for simplified smoke test

* more strides for fa integration

* fix hd64 scratches and boost performance

* non-iglp pipeline for headdim padding cases

* dpad same as dvpad for flash attention integration

* unpadded lse&d for group mode

* Support unpad layout for group lse

* Support unpad lse layout for splitkv

* Fix stride for splitkv kernel

* fix unpadded lse issue in fwd splitkv

* comment

* solve lds read&write conflicts

* rename

* bias rename

* tile index revert

---------

Co-authored-by: danyao12 <danyao12>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
Co-authored-by: Qianfeng Zhang <Qianfeng.Zhang@amd.com>
This commit is contained in:
Dan Yao
2024-08-17 04:40:10 +08:00
committed by GitHub
parent 2581727d2a
commit 79a5d9c10c
43 changed files with 5515 additions and 4222 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -1,54 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename BlockFmhaShape_>
struct FmhaBwdTilePartitioner
{
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_k_, kN0), nhead_, batch_size_);
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/)
{
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
}
};
template <ck_tile::index_t kBlockSize>
struct FmhaBwdOGradDotOTilePartitioner
{
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize), nhead_, batch_size_);
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/)
{
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
}
};
} // namespace ck_tile

View File

@@ -86,7 +86,7 @@ struct FmhaFwdKernel
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_
#undef _TS_
@@ -387,7 +387,6 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
@@ -448,7 +447,6 @@ struct FmhaFwdKernel
{
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
}
if constexpr(kDoFp8StaticQuant)
{
@@ -524,7 +522,7 @@ struct FmhaFwdKernel
}
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
batch_offset_lse = query_start;
}
if constexpr(kHasDropout)
{

View File

@@ -55,7 +55,7 @@ struct FmhaFwdSplitKVCombineKernel
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
_SS_(FmhaPipeline::name) +
(pn.empty() ? "" : "_" + pn) +
(kStoreLSE ? "_lse" : "" ) +
(kStoreLSE ? "_lse" : "" ) +
(kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_
#undef _TS_
@@ -91,7 +91,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc;
@@ -116,6 +115,7 @@ struct FmhaFwdSplitKVCombineKernel
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
{
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_lse_acc;
};
struct GroupModeKargs
@@ -166,13 +166,13 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc,
nhead_stride_o_acc,
nhead_stride_o,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
{}, // placeholder for lse
{}, // placeholder for fp8_static_quant args
batch_stride_o};
batch_stride_o,
batch_stride_lse_acc};
if constexpr(kStoreLSE)
{
@@ -206,9 +206,7 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc)
{
@@ -225,7 +223,6 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc,
nhead_stride_o_acc,
nhead_stride_o,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
@@ -237,7 +234,6 @@ struct FmhaFwdSplitKVCombineKernel
{
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
}
if constexpr(kDoFp8StaticQuant)
{
@@ -274,24 +270,25 @@ struct FmhaFwdSplitKVCombineKernel
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
const long_index_t batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
long_index_t batch_offset_lse_acc = 0;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_o = query_start * kargs.row_stride_o;
batch_offset_o = query_start * kargs.row_stride_o;
batch_offset_lse_acc = query_start;
if constexpr(kStoreLSE)
{
batch_offset_lse = query_start;
}
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
@@ -306,7 +303,13 @@ struct FmhaFwdSplitKVCombineKernel
}
else
{
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
}
// for simplicity, batch stride we just modify the pointer

View File

@@ -85,7 +85,7 @@ struct FmhaFwdSplitKVKernel
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_
#undef _TS_
@@ -136,7 +136,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_lse_acc;
ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc;
@@ -216,6 +215,7 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_lse_acc;
};
struct GroupModeKargs
@@ -313,7 +313,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v,
nhead_stride_lse_acc,
nhead_stride_o_acc,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
@@ -323,7 +322,8 @@ struct FmhaFwdSplitKVKernel
{}, // placeholder for dropout
batch_stride_q,
batch_stride_k,
batch_stride_v};
batch_stride_v,
batch_stride_lse_acc};
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
@@ -394,7 +394,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc,
@@ -433,7 +432,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v,
nhead_stride_lse_acc,
nhead_stride_o_acc,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
@@ -511,8 +509,7 @@ struct FmhaFwdSplitKVKernel
long_index_t batch_offset_v = 0;
long_index_t batch_offset_bias = 0;
long_index_t batch_offset_randval = 0;
const long_index_t batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
long_index_t batch_offset_lse_acc = 0;
const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
@@ -522,8 +519,9 @@ struct FmhaFwdSplitKVKernel
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
batch_offset_lse_acc = query_start;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start * kargs.stride_v;
@@ -564,9 +562,10 @@ struct FmhaFwdSplitKVKernel
}
else
{
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;