mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Change in fwd-splitkv kernel to support num_splits=1 case (#1690)
* Change in fwd-splitkv kernel to support num_splits=1 case * Update in codegen fwd-splitkv to make num_splits > 1 cases pass * Specify instance traits in dispatch * Fix link error for fp8 kernels --------- Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -25,6 +25,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
@@ -48,7 +49,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = true; // always store LSE (acc)
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
|
||||
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadHeadDimV_ /* paddding for hdim_v */,
|
||||
BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kHasBiasGrad_,
|
||||
bool kStoreLSE_,
|
||||
bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
|
||||
bool kDoFp8StaticQuant_,
|
||||
bool kIsPagedKV_,
|
||||
bool kHasUnevenSplits_,
|
||||
|
||||
Reference in New Issue
Block a user