mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] Support WMMA (gfx12) in FMHA (#2528)
* Pass hdim to tile_example_fmha_fwd in fp8 tests
* Add WMMA support to fwd FMHA pipelines
* Tune tile sizes a bit for less spilling
fp16 256 is still quite slow
* Fix Q grad tile distribution for warp size = 32 and hdim >= 256
With AccDataType = float and warp size = 32, K0 becomes 0, K repeat is required to correcty distribute the tile.
* Use code based on BlockDropout in BlockDropoutBwd
* Fix split KV combine kernel for gfx12 (warp size 32) and make it more universal
* Fix LSE LDS tensor descriptors: kMaxSplits and kM0 were swapped, it worked on gfx9
because they both equal to 8 while on gfx12 they are 8 and 4;
* Fix Oacc LDS tensor descriptor: it was transposed even though its shape=[4 * kM0, kN1],
it worked on gfx9 because 4 * kM == kN1 == 32;
* Removing these hidden dependecies allows to support:
* any number of warps (power-of-2), not only 4;
* kN1 = 16, not only 32;
* any number of splits;
* Rename ids like o_acc_4 and Oacc4 to eliminate confusion: kNumWarps doesn't have to be 4 now
* Replace hard-coded kN1 in dispatch code with the requested tile size
* Add gfx12-specific tile sizes for split KV
* Pass GPU architecture to kernel generation scripts
This is still a temporary solution.
* Build and run FMHA CI tests for gfx12
* Fix issue after merging
* Fix bwd tile sizes
The current pipelines always read only one tile K and V tile, this
requires bk0 == bhdq and bk2 == bhdv (kK0 == kQKHeaddim and
kK2 == kVHeaddim).
* Use hardware f32->f8 on gfx12, remove v_perm
__builtin_amdgcn_perm is not needed because
__builtin_amdgcn_cvt_pk_fp8_f32 allows to specify which word (16 bit of
32-bit dword) is used to store results (two f8 values).
* Update changelog
* Add WMMA support to pagedkv
* Fix scripts after rebasing
* Support 16x16 (MFMA, WMMA) and 32x32 (MFMA) tiles in fwd and bwd BlockDropout
Add comments with dropout implementation details
Fix performance regression of fwd+dropout
* Remove some usage of type punning (reinterpret_cast with ref or ptr) in Philox;
* "scalarize" seed and offset, they may come either from kernel args or from device memory
(presumably loaded with vector loads).
These changes help the compiler to procude more optimal code and reduce register spilling.
Use WarpGemmDispatcher instead of explicit WarpGemmMfma... to get CWarpDstrEncoding
Use code based on BlockDropout in BlockDropoutBwd
Refactor BlockDropout (fwd)
Implement BlockDropout (fwd) for WMMA
Originally BlockDropout only supported 32x32 tiles (IsWG32 = true),
this version supports 16x16 tiles.
If MPerBlock > MWarp * 16, it can generate numbers for two 16x16 tiles, similarly
to BlockDropoutBwd.
Implement BlockDropoutBwd for WMMA
Remove MakeRandValLds* functions unused in BlockDropoutBwd
Remove unused Run overload from BlockDropoutBwd
* Fix regression with philox seed and offset when they exceed 32-bit int
__builtin_amdgcn_readfirstlane works with 32-bit values, seed and offset
are 64-bit so they get truncated.
* Fix names after cherry-picking
* Fix selection of a fallback tile based on bm0
The assumption that the largest bm0 == 128 is not always true for
current fp32 tiles.
* Do not use filters related to qr_async_trload
They disable tiles/pipelines which are valid for gfx12.
* Use different dstr encoding when C is transposed
* Do not call GetQKBlockGemm (and hence WarpGemmDispatcher) in host code
Some WarpGemmDispatcher instantiations are defined only
for specific archs and undefined on host.
Calculations related to sched barriers are moved from Pipeline's public
fields into pipeline's operator().
* Fix incorrect name WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution
Correct name is WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution
because it's 32x32x16 with IterateK = 2 so K = 32, also all tiles used
in codegen scripts are 32, 32, 32.
* Generalize usages of WarpGemmDispatcher for MFMA and WMMA
WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution is still
used explicitly becaus of swizzle factor = 4.
* Mark has_load_tr as maybe_unused
There are no transpose loading for RDNA.
* Remove CK_TILE_USE_MFMA/WMMA from fmha-related code
* Detect BlockSize on host based on warp size of the current device
If kBlockSize == kNumWarps * get_warp_size(), the kernel is launched with
kBlockSize / 2 because on host get_warp_size() == 64 always.
* Fix calculation of grid size for combine kernel with warp size = 32
* Add missing includes and header
* Support multiple archs in one binary for fwd
* Support multiple archs in one binary for fwd_splitkv, fwd_appendkv, pagedkv_prefill
* Support multiple archs in one binary for bwd
* trload kernels are compiled only for gfx950;
* instances with padding are checked after instances without padding so
they can be used as fallbacks (similarly to fwd);
* Extract common code from register_traits
* Revert "Fix regression with philox seed and offset when they exceed 32-bit int"
To simplify merging , the proper fix is in develop already.
* Support new numerical d paddings in trait ordering checks
* Build fp32 tests only on gfx9
* Do not use hardcoded M0 = 64 for dot bwd kernel
* Use textwrap.indent from standard library
* Make fp8 pipelines on gfx12 consistent with gfx9
* Update tests for current pipelines
* Make ninja check more responsive in CI
ninja buffers output so this job looks hanging.
* Support fp8fp32 by limiting O vector size
The fp32 output type requires storing 8 * sizeof(float) = 32 bytes,
which is not implemented (here 8 is the number of C values per lane for
v_wmma_f32_16x16x16...).
* Remove unused cmake options
* Unify including amd_buffer_addressing.hpp/_builtins.hpp
* Temporarily use amd_buffer_addressing.hpp on >=gfx10
amd_buffer_addressing_builtins.hpp uses inline asm for loads/stores
which is not compatible with >=gfx10:
* 1 scalar for exec masks instead of 2,
* gfx12 uses different instruction names etc.
* Update asm in bf16 conversions to work with warp 32
* Do not generate splitkv/appendkv with vlayout=col for consistency with fwd
* Add arch tags to kernels/host funcs, compile for each arch separately
* Add kM0 to fmha_bwd_dot_do_o kernel name to match filename
* Add workaround for miscompilation of bwd with padded hdim
SWDEV-559729: v_wmma instructions can be incorrectly placed in divergent
branches used to store padded tensors (when some lanes are inactive due
to padding). Inline asm with dummy dependencies on VGPRs of the tensors
prevents the compiler doing this.
* Fix add_gtest_executable for absolute paths
Some tests (like gemm_tile_engine) pass absolute paths to source files.
In CI the branch name is a part of the root dir, and if the branch name
contains "wmma", "xdl" etc., files can be incorrectly excluded.
* Run only hdim 128 smoke tests for fp8fp32
There are no instances for hdim 64 and 256.
* Format py with ruff to simplify merging develop
* Fix incorrect var name
* Codegen for gfx9,gfx950 when --targets is not specified
Aiter and Pytorch require changes for passing their targets to the codegen scripts.
With this temporary solution the files are generated but not all of them
have to be really built (depending on the used --offload-arch=).
* Combine arch-related values into ArchTrait
This more centralized approach removes duplication of various formatting templates.
* Try a workaround for Jenkins error "groovyjarjarasm.asm.MethodTooLargeException: Method too large"
Some code is extracted into a function.
This commit is contained in:
@@ -692,7 +692,17 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
CK_TILE_HOST static dim3 BlockSize()
|
||||
{
|
||||
if(is_wave32())
|
||||
{
|
||||
return dim3(kBlockSize / 2);
|
||||
}
|
||||
else
|
||||
{
|
||||
return dim3(kBlockSize);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
|
||||
@@ -677,7 +677,17 @@ struct FmhaBwdDQDKDVKernel
|
||||
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
CK_TILE_HOST static dim3 BlockSize()
|
||||
{
|
||||
if(is_wave32())
|
||||
{
|
||||
return dim3(kBlockSize / 2);
|
||||
}
|
||||
else
|
||||
{
|
||||
return dim3(kBlockSize);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
@@ -1171,6 +1181,21 @@ struct FmhaBwdDQDKDVKernel
|
||||
scale_rp_undrop,
|
||||
dropout);
|
||||
|
||||
#if defined(__gfx12__)
|
||||
// Workaround for a compiler bug (SWDEV-559729): v_wmma instructions can be incorrectly
|
||||
// placed in divergent branches used to store padded tensors (when some lanes are
|
||||
// inactive due to padding). Inline asm with dummy dependencies on VGPRs of the tensors
|
||||
// prevents the compiler doing this.
|
||||
if constexpr(kPadHeadDimQ > 0)
|
||||
{
|
||||
impl::insert_dummy_dep(dk_acc_tile.get_thread_buffer());
|
||||
}
|
||||
if constexpr(kPadHeadDimV > 0)
|
||||
{
|
||||
impl::insert_dummy_dep(dv_acc_tile.get_thread_buffer());
|
||||
}
|
||||
#endif
|
||||
|
||||
KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile, nullptr);
|
||||
VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile, nullptr);
|
||||
}
|
||||
@@ -1241,7 +1266,7 @@ struct FmhaBwdOGradDotOKernel
|
||||
return n.empty() ? n : std::string("p") + n; }();
|
||||
return
|
||||
_SS_("fmha_bwd_dot_do_o_d") + _TS_(kVHeaddim) + "_" + _SS_(t2s<ODataType>::name) +
|
||||
"_" + (kIsGroupMode ? "group" : "batch") + "_" +
|
||||
"_b" + _TS_(kM0) + "_" + (kIsGroupMode ? "group" : "batch") + "_" +
|
||||
("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn);
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
@@ -1371,7 +1396,7 @@ struct FmhaBwdOGradDotOKernel
|
||||
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
CK_TILE_HOST static dim3 BlockSize() { return dim3(kBlockSize); }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
|
||||
|
||||
@@ -1678,7 +1703,7 @@ struct FmhaBwdConvertQGradKernel
|
||||
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
CK_TILE_HOST static dim3 BlockSize() { return dim3(kBlockSize); }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ struct FmhaFwdAppendKVKernel
|
||||
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
|
||||
// clang-format on
|
||||
|
||||
__host__ static std::string GetName()
|
||||
CK_TILE_HOST static std::string GetName()
|
||||
{
|
||||
// sync with generate.py
|
||||
// clang-format off
|
||||
@@ -143,41 +143,41 @@ struct FmhaFwdAppendKVKernel
|
||||
{
|
||||
};
|
||||
|
||||
__host__ static constexpr Kargs MakeKargs(void* q_ptr,
|
||||
void* k_ptr,
|
||||
const void* knew_ptr,
|
||||
void* v_ptr,
|
||||
const void* vnew_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t seqlen_knew,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
const void* rotary_cos_ptr,
|
||||
const void* rotary_sin_ptr,
|
||||
ck_tile::index_t rotary_dim,
|
||||
bool has_mask,
|
||||
const void* block_table_ptr,
|
||||
ck_tile::index_t batch_stride_block_table,
|
||||
ck_tile::index_t page_block_size,
|
||||
const void* cache_batch_idx,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_knew,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_vnew,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_knew,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_vnew,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_knew,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_vnew)
|
||||
CK_TILE_HOST static constexpr Kargs MakeKargs(void* q_ptr,
|
||||
void* k_ptr,
|
||||
const void* knew_ptr,
|
||||
void* v_ptr,
|
||||
const void* vnew_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t seqlen_knew,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
const void* rotary_cos_ptr,
|
||||
const void* rotary_sin_ptr,
|
||||
ck_tile::index_t rotary_dim,
|
||||
bool has_mask,
|
||||
const void* block_table_ptr,
|
||||
ck_tile::index_t batch_stride_block_table,
|
||||
ck_tile::index_t page_block_size,
|
||||
const void* cache_batch_idx,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_knew,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_vnew,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_knew,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_vnew,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_knew,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_vnew)
|
||||
{
|
||||
Kargs kargs{
|
||||
{q_ptr,
|
||||
@@ -255,7 +255,7 @@ struct FmhaFwdAppendKVKernel
|
||||
return ck_tile::make_tuple(i_tile, i_nhead, i_batch);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
CK_TILE_HOST static dim3 BlockSize() { return dim3(kBlockSize); }
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
|
||||
@@ -1079,7 +1079,17 @@ struct FmhaFwdKernel
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
CK_TILE_HOST static dim3 BlockSize()
|
||||
{
|
||||
if(is_wave32())
|
||||
{
|
||||
return dim3(kBlockSize / 2);
|
||||
}
|
||||
else
|
||||
{
|
||||
return dim3(kBlockSize);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
|
||||
@@ -865,7 +865,17 @@ struct FmhaFwdPagedKVKernel
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
CK_TILE_HOST static dim3 BlockSize()
|
||||
{
|
||||
if(is_wave32())
|
||||
{
|
||||
return dim3(kBlockSize / 2);
|
||||
}
|
||||
else
|
||||
{
|
||||
return dim3(kBlockSize);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
|
||||
@@ -37,7 +37,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
|
||||
// clang-format on
|
||||
|
||||
__host__ static std::string GetName()
|
||||
CK_TILE_HOST static std::string GetName()
|
||||
{
|
||||
// sync with generate.py
|
||||
// clang-format off
|
||||
@@ -127,7 +127,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
|
||||
|
||||
template <bool Cond = !kIsGroupMode>
|
||||
__host__ static constexpr std::enable_if_t<Cond, Kargs>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* lse_acc_ptr,
|
||||
const void* o_acc_ptr,
|
||||
void* lse_ptr,
|
||||
@@ -185,7 +185,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
}
|
||||
|
||||
template <bool Cond = kIsGroupMode>
|
||||
__host__ static constexpr std::enable_if_t<Cond, Kargs>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* lse_acc_ptr,
|
||||
const void* o_acc_ptr,
|
||||
void* lse_ptr,
|
||||
@@ -240,8 +240,10 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_v)
|
||||
{
|
||||
// Recalculate kM0 = get_warp_size() / NThreads on host
|
||||
const index_t m0 = (is_wave32() ? 32 : 64) / FmhaPipeline::Problem::NThreads;
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
|
||||
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, m0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1),
|
||||
nhead,
|
||||
batch_size);
|
||||
@@ -266,7 +268,17 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
CK_TILE_HOST static dim3 BlockSize()
|
||||
{
|
||||
if(is_wave32())
|
||||
{
|
||||
return dim3(kBlockSize / 2);
|
||||
}
|
||||
else
|
||||
{
|
||||
return dim3(kBlockSize);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
@@ -344,7 +356,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
const auto lse_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
lse_acc_ptr,
|
||||
make_tuple(kargs.num_splits, kargs.seqlen_q),
|
||||
make_tuple(kargs.split_stride_lse_acc, 1),
|
||||
make_tuple(kargs.split_stride_lse_acc, number<1>{}),
|
||||
number<FmhaPipeline::kAlignmentLSEacc>{},
|
||||
number<1>{});
|
||||
|
||||
@@ -358,11 +370,11 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_acc_ptr,
|
||||
make_tuple(kargs.num_splits, kargs.seqlen_q, kargs.hdim_v),
|
||||
make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1),
|
||||
make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, number<1>{}),
|
||||
number<FmhaPipeline::kAlignmentOacc>{},
|
||||
number<1>{});
|
||||
|
||||
// read 4 * (kM0, kN1) o_acc tiles simultaneously by 4 warps
|
||||
// read kNumWarps * (kM0, kN1) o_acc tiles simultaneously by kNumWarps warps
|
||||
const auto o_acc_dram_view = pad_tensor_view(
|
||||
o_acc_dram_naive,
|
||||
make_tuple(
|
||||
@@ -469,7 +481,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.hdim_v),
|
||||
make_tuple(kargs.row_stride_o, 1),
|
||||
make_tuple(kargs.row_stride_o, number<1>{}),
|
||||
number<FmhaPipeline::kAlignmentO>{},
|
||||
number<1>{});
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ struct FmhaFwdSplitKVKernel
|
||||
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
|
||||
// clang-format on
|
||||
|
||||
__host__ static std::string GetName()
|
||||
CK_TILE_HOST static std::string GetName()
|
||||
{
|
||||
// sync with generate.py
|
||||
// clang-format off
|
||||
@@ -279,7 +279,7 @@ struct FmhaFwdSplitKVKernel
|
||||
};
|
||||
|
||||
template <bool Cond = !kIsGroupMode>
|
||||
__host__ static constexpr std::enable_if_t<Cond, Kargs>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
@@ -409,7 +409,7 @@ struct FmhaFwdSplitKVKernel
|
||||
}
|
||||
|
||||
template <bool Cond = kIsGroupMode>
|
||||
__host__ static constexpr std::enable_if_t<Cond, Kargs>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
@@ -574,7 +574,17 @@ struct FmhaFwdSplitKVKernel
|
||||
}
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
CK_TILE_HOST static dim3 BlockSize()
|
||||
{
|
||||
if(is_wave32())
|
||||
{
|
||||
return dim3(kBlockSize / 2);
|
||||
}
|
||||
else
|
||||
{
|
||||
return dim3(kBlockSize);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
|
||||
@@ -683,26 +682,26 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradAccDramTileDistribution()
|
||||
{
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::kQKHeaddim;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(AccDataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t K2 = GetAlignmentPostQGradAcc<Problem>();
|
||||
constexpr index_t K1 = min(kKPerBlock / K2, get_warp_size());
|
||||
constexpr index_t K0 = kKPerBlock / (K1 * K2);
|
||||
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
constexpr index_t M2 = get_warp_size() / K1;
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M0 = kMPerBlock / (M1 * M2);
|
||||
|
||||
constexpr auto dstr = make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<1>, sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<2>, sequence<2, 3>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2, 3>,
|
||||
sequence<0, 0, 1>>{});
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<1>, sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<2>, sequence<2, 3>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
sequence<1, 2, 3, 3>,
|
||||
sequence<0, 0, 0, 2>>{});
|
||||
static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kMPerBlock * kKPerBlock);
|
||||
return dstr;
|
||||
@@ -711,27 +710,25 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradDramTileDistribution()
|
||||
{
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::kQKHeaddim;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(AccDataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t K2 = GetAlignmentPostQGrad<Problem>();
|
||||
constexpr index_t K1 = min(kKPerBlock / K2, get_warp_size());
|
||||
constexpr index_t K0 = kKPerBlock / (K1 * K2);
|
||||
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
constexpr index_t M2 = get_warp_size() / K1;
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M0 = kMPerBlock / (M1 * M2);
|
||||
|
||||
constexpr auto dstr = make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kMPerBlock * kKPerBlock);
|
||||
return dstr;
|
||||
|
||||
@@ -31,59 +31,33 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVSDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
constexpr auto warp_gemm = []() {
|
||||
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
|
||||
|
||||
if constexpr(std::is_same_v<typename Problem::QDataType, float> &&
|
||||
std::is_same_v<typename Problem::KDataType, float> &&
|
||||
if constexpr(get_warp_size() == 64 &&
|
||||
std::is_same_v<typename Problem::QDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
static_assert(WarpGemmM == 16);
|
||||
static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32);
|
||||
static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32);
|
||||
static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32);
|
||||
|
||||
// TODO: hard coded here. Otherwise, it produces incorrect results
|
||||
constexpr index_t swizzle_factor = 4;
|
||||
return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<
|
||||
swizzle_factor>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr bool SwizzleA =
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32;
|
||||
return WarpGemmDispatcher<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
|
||||
true>{};
|
||||
true, // TransposeC
|
||||
SwizzleA>{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
|
||||
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaF16F16F32M4N64K16{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
|
||||
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
static_assert(WarpGemmM == 32);
|
||||
|
||||
// TODO: hard coded here. Otherwise, it may incorrect result
|
||||
constexpr index_t swizzle_factor = 4;
|
||||
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
|
||||
swizzle_factor>{};
|
||||
} // TODO - bf8_t
|
||||
}();
|
||||
|
||||
using BlockGemmPolicy =
|
||||
|
||||
@@ -273,16 +273,18 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
|
||||
}
|
||||
|
||||
auto o_acc_4_dist = Policy::template MakeOacc4DramTileDistribution<Problem>();
|
||||
auto o_acc_4_dram_window =
|
||||
// First each warp processes its own part of splits
|
||||
|
||||
auto o_acc_dist = Policy::template MakeOaccDramTileDistribution<Problem>();
|
||||
auto o_acc_dram_window =
|
||||
make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
o_acc_dram_block_window_tmp.get_window_lengths(),
|
||||
o_acc_dram_block_window_tmp.get_window_origin(),
|
||||
o_acc_4_dist);
|
||||
o_acc_dist);
|
||||
|
||||
// shape=[4 * KM0, kN1]
|
||||
auto o_acc_4 = make_static_distributed_tensor<OaccDataType>(o_acc_4_dist);
|
||||
clear_tile(o_acc_4);
|
||||
// shape=[kNumWarps * KM0, kN1]
|
||||
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
|
||||
clear_tile(o_acc);
|
||||
|
||||
const index_t padded_num_splits = integer_divide_ceil(num_splits, kNumWarps) * kNumWarps;
|
||||
|
||||
@@ -291,73 +293,73 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
// each warp handles a [KM0, kN1] tile
|
||||
for(index_t split_start = 0; split_start < padded_num_splits; split_start += kNumWarps)
|
||||
{
|
||||
auto o_tile = load_tile(o_acc_4_dram_window);
|
||||
auto o_tile = load_tile(o_acc_dram_window);
|
||||
const index_t i_split = split_start + get_warp_id();
|
||||
const index_t row_start = kM0 * get_warp_id();
|
||||
{
|
||||
constexpr auto spans = decltype(o_acc_4)::get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
o_acc_4.get_tile_distribution(), i_j_idx);
|
||||
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
|
||||
const LSEDataType lse_scale = lse_acc_lds(row - row_start, i_split);
|
||||
o_acc_4(i_j_idx) += lse_scale * o_tile(i_j_idx);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
move_tile_window(o_acc_4_dram_window, {kNumWarps * kM0, 0});
|
||||
}
|
||||
|
||||
// 4 o_acc tiles in LDS. shape=[4 * kM0, kN1]
|
||||
OaccDataType* o_acc_4_lds_ptr = static_cast<OaccDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeLSEacc<Problem>()));
|
||||
|
||||
{
|
||||
auto o_acc_4_lds_window = [&]() {
|
||||
auto desc = Policy::template MakeOacc4LdsBlockDescriptor<Problem>();
|
||||
auto view = make_tensor_view<address_space_enum::lds>(o_acc_4_lds_ptr, desc);
|
||||
return make_tile_window(view, desc.get_lengths(), {0, 0});
|
||||
}();
|
||||
store_tile(o_acc_4_lds_window, o_acc_4);
|
||||
}
|
||||
|
||||
auto o_acc_dist = Policy::template MakeOaccDramTileDistribution<Problem>();
|
||||
|
||||
auto o_acc_4_lds_window = [&]() {
|
||||
auto desc = Policy::template MakeOacc4LdsBlockDescriptor<Problem>();
|
||||
auto view = make_tensor_view<address_space_enum::lds>(o_acc_4_lds_ptr, desc);
|
||||
return make_tile_window(view, desc.get_lengths(), {0, 0}, o_acc_dist);
|
||||
}();
|
||||
|
||||
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
|
||||
clear_tile(o_acc);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
block_sync_lds();
|
||||
static_for<0, kNumWarps, 1>{}([&](auto) {
|
||||
auto o_acc_in = load_tile(o_acc_4_lds_window);
|
||||
|
||||
{
|
||||
constexpr auto spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
o_acc(i_j_idx) += o_acc_in(i_j_idx);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
o_acc.get_tile_distribution(), i_j_idx);
|
||||
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
|
||||
const LSEDataType lse_scale = lse_acc_lds(row - row_start, i_split);
|
||||
o_acc(i_j_idx) += lse_scale * o_tile(i_j_idx);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
move_tile_window(o_acc_4_lds_window, {kM0, 0});
|
||||
move_tile_window(o_acc_dram_window, {kNumWarps * kM0, 0});
|
||||
}
|
||||
|
||||
// Then each warps combines partial o_acc results into one
|
||||
|
||||
// kNumWarps o_acc tiles in LDS. shape=[kNumWarps * kM0, kN1]
|
||||
OaccDataType* o_acc_lds_ptr = static_cast<OaccDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeLSEacc<Problem>()));
|
||||
|
||||
{
|
||||
auto o_acc_lds_store_window = [&]() {
|
||||
auto desc = Policy::template MakeOaccLdsBlockDescriptor<Problem>();
|
||||
auto view = make_tensor_view<address_space_enum::lds>(o_acc_lds_ptr, desc);
|
||||
return make_tile_window(view, desc.get_lengths(), {0, 0});
|
||||
}();
|
||||
store_tile(o_acc_lds_store_window, o_acc);
|
||||
}
|
||||
|
||||
auto o_acc_result_dist = Policy::template MakeOaccResultDramTileDistribution<Problem>();
|
||||
|
||||
auto o_acc_lds_load_window = [&]() {
|
||||
auto desc = Policy::template MakeOaccLdsBlockDescriptor<Problem>();
|
||||
auto view = make_tensor_view<address_space_enum::lds>(o_acc_lds_ptr, desc);
|
||||
return make_tile_window(view, desc.get_lengths(), {0, 0}, o_acc_result_dist);
|
||||
}();
|
||||
|
||||
auto o_acc_result = make_static_distributed_tensor<OaccDataType>(o_acc_result_dist);
|
||||
clear_tile(o_acc_result);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
block_sync_lds();
|
||||
static_for<0, kNumWarps, 1>{}([&](auto) {
|
||||
auto o_acc_in = load_tile(o_acc_lds_load_window);
|
||||
|
||||
{
|
||||
constexpr auto spans = decltype(o_acc_result)::get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
o_acc_result(i_j_idx) += o_acc_in(i_j_idx);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
move_tile_window(o_acc_lds_load_window, {kM0, 0});
|
||||
});
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
return tile_elementwise_in(o_acc_element_func, o_acc_result);
|
||||
}
|
||||
|
||||
template <typename LSEaccDramBlockWindow,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -52,11 +52,11 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
|
||||
{
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNumWarps = Problem::kNumWarps;
|
||||
constexpr index_t kMPerBlock = Problem::kM0;
|
||||
constexpr index_t kNPerBlock = Problem::kN1;
|
||||
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M1 = kNumWarps;
|
||||
constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
|
||||
constexpr index_t N0 = get_warp_size() / M2;
|
||||
constexpr index_t N1 = kNPerBlock / N0;
|
||||
@@ -78,16 +78,16 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOacc4()
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOacc()
|
||||
{
|
||||
return sizeof(typename Problem::OaccDataType) *
|
||||
MakeOacc4LdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
MakeOaccLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return GetSmemSizeLSEacc<Problem>() + GetSmemSizeOacc4<Problem>();
|
||||
return GetSmemSizeLSEacc<Problem>() + GetSmemSizeOacc<Problem>();
|
||||
}
|
||||
|
||||
// shape=[kMaxSplits, kM0]
|
||||
@@ -129,8 +129,8 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
|
||||
{
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::kM0;
|
||||
constexpr index_t kNPerBlock = Problem::kMaxSplits;
|
||||
constexpr index_t kMPerBlock = Problem::kMaxSplits;
|
||||
constexpr index_t kNPerBlock = Problem::kM0;
|
||||
constexpr index_t NPack =
|
||||
GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
|
||||
|
||||
@@ -142,8 +142,9 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
|
||||
|
||||
constexpr auto lse_acc_lds_block_desc = transform_tensor_descriptor(
|
||||
lse_acc_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))),
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<kMPerBlock>{}),
|
||||
make_merge_transform(make_tuple(number<kNPerBlock / NPack>{}, number<NPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
@@ -156,8 +157,8 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
|
||||
{
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::kM0;
|
||||
constexpr index_t kNPerBlock = Problem::kMaxSplits;
|
||||
constexpr index_t kMPerBlock = Problem::kMaxSplits;
|
||||
constexpr index_t kNPerBlock = Problem::kM0;
|
||||
constexpr index_t NPack =
|
||||
GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
|
||||
|
||||
@@ -169,21 +170,23 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
|
||||
|
||||
constexpr auto lse_acc_t_lds_block_desc = transform_tensor_descriptor(
|
||||
lse_acc_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))),
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<kMPerBlock>{}),
|
||||
make_merge_transform(make_tuple(number<kNPerBlock / NPack>{}, number<NPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return lse_acc_t_lds_block_desc;
|
||||
}
|
||||
|
||||
// 3d + padding, shape=[4 * kM0, kN1]
|
||||
// 3d + padding, shape=[kNumWarps * kM0, kN1]
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeOacc4LdsBlockDescriptor()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeOaccLdsBlockDescriptor()
|
||||
{
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
|
||||
constexpr index_t kMPerBlock = 4 * Problem::kM0;
|
||||
constexpr index_t kNumWarps = Problem::kNumWarps;
|
||||
constexpr index_t kMPerBlock = kNumWarps * Problem::kM0;
|
||||
constexpr index_t kNPerBlock = Problem::kN1;
|
||||
constexpr index_t NPack =
|
||||
GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
|
||||
@@ -191,17 +194,17 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
|
||||
constexpr auto o_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
|
||||
number<8>{},
|
||||
number<NPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto o_acc_t_lds_block_desc = transform_tensor_descriptor(
|
||||
constexpr auto o_acc_lds_block_desc = transform_tensor_descriptor(
|
||||
o_acc_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return o_acc_t_lds_block_desc;
|
||||
return o_acc_lds_block_desc;
|
||||
}
|
||||
|
||||
// shape=[kM0, kMaxSplits]
|
||||
@@ -235,12 +238,13 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
|
||||
sequence<2, 1>>{});
|
||||
}
|
||||
|
||||
// similar to MakeOaccDramTileDistribution(), but duplicate same 1-warp encoding 4 times on M
|
||||
// direction
|
||||
// similar to MakeOaccResultDramTileDistribution(), but duplicate same 1-warp encoding kNumWarps
|
||||
// times on M direction
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeOacc4DramTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution()
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::kM0; // real kMPerBlock we want is (4 * kM0)
|
||||
constexpr index_t kNumWarps = Problem::kNumWarps;
|
||||
constexpr index_t kMPerBlock = Problem::kM0; // real kMPerBlock we want is (kNumWarps * kM0)
|
||||
constexpr index_t kNPerBlock = Problem::kN1;
|
||||
static_assert(get_warp_size() <= kMPerBlock * kNPerBlock);
|
||||
|
||||
@@ -252,7 +256,7 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<4, M0, M1, M2>, sequence<N0, N1>>,
|
||||
tuple<sequence<kNumWarps, M0, M1, M2>, sequence<N0, N1>>,
|
||||
tuple<sequence<1, 1>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 2>, sequence<3, 0>>,
|
||||
sequence<1, 2>,
|
||||
@@ -260,14 +264,14 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeOaccResultDramTileDistribution()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNumWarps = Problem::kNumWarps;
|
||||
constexpr index_t kMPerBlock = Problem::kM0;
|
||||
constexpr index_t kNPerBlock = Problem::kN1;
|
||||
static_assert(kBlockSize <= kMPerBlock * kNPerBlock);
|
||||
static_assert(kNumWarps * get_warp_size() <= kMPerBlock * kNPerBlock);
|
||||
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M1 = kNumWarps;
|
||||
constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
|
||||
constexpr index_t N0 = get_warp_size() / M2;
|
||||
constexpr index_t N1 = kNPerBlock / N0;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -204,6 +204,7 @@ struct BlockFmhaSplitKVCombinePipelineProblem
|
||||
|
||||
using BaseType::kM0;
|
||||
using BaseType::kN1;
|
||||
using BaseType::NThreads;
|
||||
|
||||
static_assert(kN1 <= kHeadDimV && kHeadDimV % kN1 == 0);
|
||||
|
||||
@@ -216,7 +217,7 @@ struct BlockFmhaSplitKVCombinePipelineProblem
|
||||
static constexpr index_t kMaxSplits = Traits::kMaxSplits;
|
||||
static_assert(8 <= kMaxSplits);
|
||||
|
||||
static constexpr index_t kNumWarps = 4; // always use 4 warps for each workgroup
|
||||
static constexpr index_t kNumWarps = 4;
|
||||
static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
|
||||
|
||||
static_assert(get_warp_size() <= (kM0 * kMaxSplits) &&
|
||||
|
||||
@@ -58,17 +58,6 @@ struct BlockFmhaPipelineQRKSVS
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
using BlockGemm0 = remove_cvref_t<decltype(Policy::template GetQKBlockGemm<Problem>())>;
|
||||
static constexpr auto WarpGemmConfig =
|
||||
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm0 = remove_cvref_t<decltype(WarpGemmConfig.template at<0>())>;
|
||||
static constexpr index_t Gemm0MWarp = WarpGemmConfig.template at<1>();
|
||||
static constexpr index_t Gemm0NWarp = WarpGemmConfig.template at<2>();
|
||||
static constexpr index_t WarpGemm0M = WarpGemm0::WarpGemmAttribute::Impl::kM;
|
||||
static constexpr index_t WarpGemm0N = WarpGemm0::WarpGemmAttribute::Impl::kN;
|
||||
static constexpr index_t WarpGemm0K = WarpGemm0::WarpGemmAttribute::Impl::kK;
|
||||
static constexpr int NumMfmaInsts =
|
||||
(kM0 / WarpGemm0M) * (kN0 / WarpGemm0N) * (kK0 / WarpGemm0K) / (Gemm0MWarp * Gemm0NWarp);
|
||||
static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read
|
||||
static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate)
|
||||
|
||||
@@ -298,7 +287,18 @@ struct BlockFmhaPipelineQRKSVS
|
||||
// Use compile-time conditional for group barrier sequence
|
||||
// (No runtime lambda selection)
|
||||
auto schedule_gemm0 = [] {
|
||||
if constexpr(kQKHeaddim == 256)
|
||||
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
|
||||
constexpr auto WarpGemmConfig =
|
||||
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm0 = remove_cvref_t<decltype(WarpGemmConfig.template at<0>())>;
|
||||
constexpr index_t Gemm0MWarp = WarpGemmConfig.template at<1>();
|
||||
constexpr index_t Gemm0NWarp = WarpGemmConfig.template at<2>();
|
||||
constexpr index_t WarpGemm0M = WarpGemm0::WarpGemmAttribute::Impl::kM;
|
||||
constexpr index_t WarpGemm0N = WarpGemm0::WarpGemmAttribute::Impl::kN;
|
||||
constexpr index_t WarpGemm0K = WarpGemm0::WarpGemmAttribute::Impl::kK;
|
||||
constexpr index_t NumMfmaInsts = (kM0 / WarpGemm0M) * (kN0 / WarpGemm0N) *
|
||||
(kK0 / WarpGemm0K) / (Gemm0MWarp * Gemm0NWarp);
|
||||
if constexpr(get_warp_size() == 64 && kQKHeaddim == 256)
|
||||
{
|
||||
static_assert(NumMfmaInsts % 8 == 0);
|
||||
static_for<0, NumMfmaInsts / 8, 1>{}([&](auto) {
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
|
||||
|
||||
@@ -263,59 +263,33 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
constexpr auto warp_gemm = []() {
|
||||
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
|
||||
|
||||
if constexpr(std::is_same_v<typename Problem::QDataType, float> &&
|
||||
std::is_same_v<typename Problem::KDataType, float> &&
|
||||
if constexpr(get_warp_size() == 64 &&
|
||||
std::is_same_v<typename Problem::QDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
static_assert(WarpGemmM == 16);
|
||||
static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32);
|
||||
static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32);
|
||||
static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32);
|
||||
|
||||
// TODO: hard coded here. Otherwise, it produces incorrect results
|
||||
constexpr index_t swizzle_factor = 4;
|
||||
return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<
|
||||
swizzle_factor>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr bool SwizzleA =
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32;
|
||||
return WarpGemmDispatcher<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
|
||||
true>{};
|
||||
true, // TransposeC
|
||||
SwizzleA>{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
|
||||
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaF16F16F32M4N64K16{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
|
||||
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
static_assert(WarpGemmM == 32);
|
||||
|
||||
// TODO: hard coded here. Otherwise, it may incorrect result
|
||||
constexpr index_t swizzle_factor = 4;
|
||||
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
|
||||
swizzle_factor>{};
|
||||
} // TODO - bf8_t
|
||||
}();
|
||||
|
||||
using BlockGemmPolicy =
|
||||
|
||||
@@ -72,59 +72,33 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
constexpr auto warp_gemm = []() {
|
||||
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
|
||||
|
||||
if constexpr(std::is_same_v<typename Problem::QDataType, float> &&
|
||||
std::is_same_v<typename Problem::KDataType, float> &&
|
||||
if constexpr(get_warp_size() == 64 &&
|
||||
std::is_same_v<typename Problem::QDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
static_assert(WarpGemmM == 16);
|
||||
static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32);
|
||||
static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32);
|
||||
static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32);
|
||||
|
||||
// TODO: hard coded here. Otherwise, it produces incorrect results
|
||||
constexpr index_t swizzle_factor = 4;
|
||||
return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<
|
||||
swizzle_factor>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr bool SwizzleA =
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32;
|
||||
return WarpGemmDispatcher<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
|
||||
true>{};
|
||||
true, // TransposeC
|
||||
SwizzleA>{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
|
||||
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaF16F16F32M4N64K16{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
|
||||
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
static_assert(WarpGemmM == 32);
|
||||
|
||||
// TODO: hard coded here. Otherwise, it may incorrect result
|
||||
constexpr index_t swizzle_factor = 4;
|
||||
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
|
||||
swizzle_factor>{};
|
||||
} // TODO - bf8_t
|
||||
}();
|
||||
|
||||
using BlockGemmPolicy =
|
||||
@@ -238,7 +212,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
|
||||
BlockGemmProblem<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
Problem::kBlockSize,
|
||||
Problem::kNumGemm0Warps * get_warp_size(),
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
@@ -246,59 +220,33 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
constexpr auto warp_gemm = []() {
|
||||
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
|
||||
|
||||
if constexpr(std::is_same_v<typename Problem::QDataType, float> &&
|
||||
std::is_same_v<typename Problem::KDataType, float> &&
|
||||
if constexpr(get_warp_size() == 64 &&
|
||||
std::is_same_v<typename Problem::QDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
static_assert(WarpGemmM == 16);
|
||||
static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32);
|
||||
static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32);
|
||||
static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32);
|
||||
|
||||
// TODO: hard coded here. Otherwise, it produces incorrect results
|
||||
constexpr index_t swizzle_factor = 4;
|
||||
return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<
|
||||
swizzle_factor>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr bool SwizzleA =
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32;
|
||||
return WarpGemmDispatcher<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
|
||||
true>{};
|
||||
true, // TransposeC
|
||||
SwizzleA>{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
|
||||
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaF16F16F32M4N64K16{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
|
||||
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
static_assert(WarpGemmM == 32);
|
||||
|
||||
// TODO: hard coded here. Otherwise, it may incorrect result
|
||||
constexpr index_t swizzle_factor = 4;
|
||||
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
|
||||
swizzle_factor>{};
|
||||
} // TODO - bf8_t
|
||||
}();
|
||||
|
||||
using BlockGemmPolicy =
|
||||
@@ -481,7 +429,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
return WG::WarpGemmAttribute::Impl::kCM1PerLane;
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::ODataType);
|
||||
return min(MaxVectorSize, WG::WarpGemmAttribute::Impl::kCM1PerLane);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -1019,15 +968,16 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
|
||||
|
||||
auto warp_gemm = [&]() {
|
||||
if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
|
||||
if constexpr(get_warp_size() == 64 &&
|
||||
std::is_same_v<typename Problem::PDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::VDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::OaccDataType, float>)
|
||||
{
|
||||
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<>{};
|
||||
// return
|
||||
// WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
|
||||
// WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename
|
||||
// Problem::PDataType, typename Problem::VDataType>>>{};
|
||||
static_assert(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}) == 32);
|
||||
static_assert(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32);
|
||||
static_assert(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32);
|
||||
|
||||
return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user