Change in fwd-splitkv kernel to support num_splits=1 case (#1690)

* Change in fwd-splitkv kernel to support num_splits=1 case

* Update in codegen fwd-splitkv to make num_splits > 1 cases pass

* Specify instance traits in dispatch

* Fix link error for fp8 kernels

---------

Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
Qianfeng
2024-11-25 12:31:38 +08:00
committed by GitHub
parent 19d4b79039
commit ce2bdf42a9
4 changed files with 41 additions and 26 deletions

View File

@@ -25,6 +25,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
@@ -48,7 +49,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = true; // always store LSE (acc)
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;

View File

@@ -39,7 +39,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadHeadDimV_ /* paddding for hdim_v */,
BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_,
bool kStoreLSE_,
bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
bool kDoFp8StaticQuant_,
bool kIsPagedKV_,
bool kHasUnevenSplits_,