mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
Support hdim48; bm0=32 is not correct yet, and performance is still not good
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
FWD_DTYPE_MAP = {
|
||||
"fp16" : "FmhaFwdFp16",
|
||||
# "fp16" : "FmhaFwdFp16",
|
||||
"bf16" : "FmhaFwdBf16",
|
||||
# "fp8" : "FmhaFwdFp8",
|
||||
# "fp8fp16": "FmhaFwdFp8Fp16",
|
||||
|
||||
@@ -31,6 +31,7 @@ DTYPE_BITS = {
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {
|
||||
32 : 32,
|
||||
48 : 48,
|
||||
64 : 64,
|
||||
96 : 128,
|
||||
128: 128,
|
||||
@@ -303,18 +304,18 @@ FMHA_FWD_DECODE_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_
|
||||
|
||||
// make sure we can reuse the padding flags in combine kernels
|
||||
static_assert({F_bm0} % kM0 == 0);
|
||||
static_assert({F_bn1} % 32 == 0);
|
||||
static_assert({F_bn1} % 16 == 0);
|
||||
|
||||
if (t.has_lse) {{
|
||||
if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{
|
||||
return -1;
|
||||
}} else {{
|
||||
using traits2_ = fmha_fwd_decode_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>;
|
||||
using traits2_ = fmha_fwd_decode_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/16, true, {F_squant}, {F_spad}, {F_dvpad}>;
|
||||
|
||||
return fmha_fwd_decode_<traits_, traits2_>(s, a);
|
||||
}}
|
||||
}} else {{
|
||||
using traits2_ = fmha_fwd_decode_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, false, {F_squant}, {F_spad}, {F_dvpad}>;
|
||||
using traits2_ = fmha_fwd_decode_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/16, false, {F_squant}, {F_spad}, {F_dvpad}>;
|
||||
|
||||
return fmha_fwd_decode_<traits_, traits2_>(s, a);
|
||||
}}
|
||||
@@ -649,17 +650,19 @@ class FmhaFwdSplitKVCombineKernel:
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'64': {
|
||||
'48': {
|
||||
# # Specialize for different SeqQ
|
||||
'16': FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
'32': FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '64': FmhaFwdTileSize(64, 64, 64, 64, 64, 64, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
},
|
||||
'128': {
|
||||
'16': FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
'32': FmhaFwdTileSize(32, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '64': FmhaFwdTileSize(64, 64, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
'16': FmhaFwdTileSize(16, 32, 48, 48, 32, 48, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 32, -1),
|
||||
'32': FmhaFwdTileSize(32, 32, 48, 48, 32, 48, 1, 1, 1, 1, 1, 1, 32, 32, 16, 16, 16, 32, -1),
|
||||
},
|
||||
# '64': {
|
||||
# '16': FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# '32': FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# },
|
||||
# '128': {
|
||||
# '16': FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# '32': FmhaFwdTileSize(32, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# },
|
||||
}
|
||||
else:
|
||||
return None
|
||||
@@ -668,9 +671,10 @@ def get_fmha_fwd_decode_combine_tile_dict_from_dtype(dtype : str) -> Optional[di
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
# '32' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'64' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'48' : FmhaFwdSplitKVCombineTileSize(16, -1),
|
||||
# '64' : FmhaFwdSplitKVCombineTileSize(16, -1),
|
||||
### '96' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'128' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
# '128' : FmhaFwdSplitKVCombineTileSize(16, -1),
|
||||
# '256' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
}
|
||||
else:
|
||||
@@ -692,7 +696,8 @@ def get_fwd_decode_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> T
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, pagedkv in itertools.product(["f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["f"]):
|
||||
for lse in ['t', 'f']:
|
||||
if hdim in [64, 128]: ### [32, 64, 96, 128]:
|
||||
# for lse in ['f']:
|
||||
if hdim in [48, 64, 128]: ### [32, 64, 96, 128]:
|
||||
pipelines.append(Pipeline('decode_qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('decode_qr', 'row', 'f', 'f', 't', 't', logits, bias, lse, squant, pagedkv, mask))
|
||||
else:
|
||||
|
||||
@@ -76,3 +76,7 @@
|
||||
#include "ck_tile/core/utility/transpose_vectors.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/unary_element_function.hpp"
|
||||
// Use `CK_PRINT<T1, T2, ...>()` to inspect values of type T1, T2, ...
|
||||
// Use `CK_PRINT<v1, v2, ...>()` to inspect constexpr values v1, v2, ... of the same type
|
||||
// In an unevaluated context, you can use `using _dummy = decltype(CK_PRINT<...>());`
|
||||
// Set BUILD_DEV to OFF to avoid enabling Werror
|
||||
|
||||
@@ -46,23 +46,22 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
|
||||
return rh_major_minor_to_y_;
|
||||
};
|
||||
|
||||
constexpr auto rh_major_minor_to_y_in = get_rh_major_minor_to_y(InTensor{});
|
||||
constexpr auto rh_major_minor_to_y_out = get_rh_major_minor_to_y(OutTensor{});
|
||||
constexpr auto rh_major_minor_to_y_in = get_rh_major_minor_to_y(InTensor{});
|
||||
|
||||
constexpr auto y_dim_out_to_in = [&] {
|
||||
map<index_t, index_t> y_dim_out_to_in_;
|
||||
|
||||
for(const auto& [rh_major_minor, y_out] : rh_major_minor_to_y_out)
|
||||
{
|
||||
y_dim_out_to_in_(y_out) = rh_major_minor_to_y_in[rh_major_minor];
|
||||
}
|
||||
|
||||
return y_dim_out_to_in_;
|
||||
}();
|
||||
|
||||
//
|
||||
constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
|
||||
|
||||
using OutDstrEncode = typename decltype(out_tensor.get_tile_distribution())::DstrEncode;
|
||||
// using InDstrEncode = typename decltype(in_tensor.get_tile_distribution())::DstrEncode;
|
||||
|
||||
constexpr auto y_dim_out_to_in = generate_sequence_v2(
|
||||
[&](auto i) constexpr {
|
||||
constexpr index_t rh_major_out = OutDstrEncode::ys_to_rhs_major_[i];
|
||||
constexpr index_t rh_minor_out = OutDstrEncode::ys_to_rhs_minor_[i];
|
||||
|
||||
return number<rh_major_minor_to_y_in[{rh_major_out, rh_minor_out}]>{};
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());
|
||||
|
||||
// input and output vector dim in the order of input Y dims
|
||||
@@ -128,7 +127,7 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
|
||||
|
||||
// set output vectors
|
||||
static_for<0, num_vec_out, 1>{}([&](auto i) {
|
||||
constexpr auto idx_y_out_tmp = generate_array(
|
||||
constexpr auto idx_y_out_tmp = generate_tuple(
|
||||
[&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
|
||||
number<NDimY>{});
|
||||
|
||||
|
||||
@@ -229,4 +229,13 @@ constexpr auto conditional_expr(X&& x, Y&& y)
|
||||
}
|
||||
}
|
||||
|
||||
// Compile-time inspection helper taking a pack of constant values as
// non-type template arguments.
// The body is intentionally empty: the function exists only to carry the
// [[deprecated]] attribute, so that naming it makes the compiler emit a
// deprecation diagnostic.
// NOTE(review): presumably that diagnostic spells out the substituted template
// arguments, which is how the values get "printed" -- confirm with the target
// compiler.
template <auto... val>
|
||||
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
|
||||
{
|
||||
}
|
||||
// Overload of the same helper taking a pack of types instead of values.
// Also empty; it relies on the same [[deprecated]] diagnostic.
template <typename... type>
|
||||
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
|
||||
{
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -676,6 +676,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_prefetch);
|
||||
// CK_PRINT<decltype(v_shuffle_tmp), decltype(v_prefetch)>();
|
||||
store_tile(
|
||||
v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
|
||||
|
||||
@@ -48,18 +48,49 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(QDataType);
|
||||
|
||||
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
|
||||
static_assert(0 < ElemPerThread);
|
||||
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
|
||||
constexpr index_t kMaxVecLoad = [&](){
|
||||
// Try dwordx4
|
||||
if constexpr (ElemPerThread % MaxVectorSize == 0){
|
||||
return MaxVectorSize;
|
||||
}
|
||||
// Try dwordx2
|
||||
else if constexpr (ElemPerThread % (MaxVectorSize / 2) == 0){
|
||||
return MaxVectorSize / 2;
|
||||
}
|
||||
// Try dword
|
||||
else if constexpr (ElemPerThread % (MaxVectorSize / 4) == 0){
|
||||
return MaxVectorSize / 4;
|
||||
}
|
||||
else{
|
||||
return 1;
|
||||
}
|
||||
}();
|
||||
|
||||
constexpr index_t KPerThread = kMaxVecLoad;
|
||||
constexpr index_t KThreads = kKPerBlock / KPerThread;
|
||||
// if false, we can not distribute the Ps thread evenly over Hs.
|
||||
constexpr bool KThreadEven = (ElemPerThread/KPerThread) % 2 == 0;
|
||||
constexpr index_t KThreads = [&](){
|
||||
if constexpr (KThreadEven){
|
||||
return kKPerBlock / KPerThread;
|
||||
}
|
||||
else{
|
||||
// We move the odd factor to multiple instruction issue
|
||||
return kKPerBlock/ElemPerThread;
|
||||
}
|
||||
}();
|
||||
|
||||
constexpr index_t KRepeat = kKPerBlock / KThreads / KPerThread;
|
||||
static_assert(KRepeat == ElemPerThread/kMaxVecLoad);
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
@@ -67,11 +98,154 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
sequence<KRepeat, KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
|
||||
// Builds the static tile distribution for a kN0 x kK0 tile (both extents read
// from Problem::BlockFmhaShape) spread over Problem::kBlockSize threads.
// Per its name this is the distribution for the K tile in DRAM.
//
// Everything is computed at compile time:
//   - ElemPerThread: elements of the tile handled by each thread.
//   - kMaxVecLoad / KPerThread: the per-access vector width, chosen as the
//     largest of MaxVectorSize, MaxVectorSize/2, MaxVectorSize/4 that divides
//     ElemPerThread, falling back to 1.
//   - KThreads / KRepeat: how the K extent is split between threads and
//     repeated accesses by the same thread.
//   - NThreadPerWarp / NumWarps / NPerThread: how the N extent is split.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
|
||||
{
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
|
||||
// Number of KDataType elements that fit in 16 bytes.
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
|
||||
// Integer division: any remainder of (kNPerBlock * kKPerBlock) / kBlockSize is dropped.
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
// Pick the widest vector width that evenly divides the per-thread element count.
constexpr index_t kMaxVecLoad = [&](){
|
||||
// Try dwordx4
|
||||
if constexpr (ElemPerThread % MaxVectorSize == 0){
|
||||
return MaxVectorSize;
|
||||
}
|
||||
// Try dwordx2
|
||||
else if constexpr (ElemPerThread % (MaxVectorSize / 2) == 0){
|
||||
return MaxVectorSize / 2;
|
||||
}
|
||||
// Try dword
|
||||
else if constexpr (ElemPerThread % (MaxVectorSize / 4) == 0){
|
||||
return MaxVectorSize / 4;
|
||||
}
|
||||
else{
|
||||
return 1;
|
||||
}
|
||||
}();
|
||||
|
||||
constexpr index_t KPerThread = kMaxVecLoad;
|
||||
// if false, we can not distribute the Ps thread evenly over Hs.
|
||||
// True when the number of vector accesses per thread (ElemPerThread / KPerThread) is even.
constexpr bool KThreadEven = (ElemPerThread/KPerThread) % 2 == 0;
|
||||
constexpr index_t KThreads = [&](){
|
||||
if constexpr (KThreadEven){
|
||||
return kKPerBlock / KPerThread;
|
||||
}
|
||||
else{
|
||||
// We move the odd factor to multiple instruction issue
|
||||
return kKPerBlock/ElemPerThread;
|
||||
}}();
|
||||
|
||||
constexpr index_t KRepeat = kKPerBlock / KThreads / KPerThread;
|
||||
// NOTE(review): on the KThreadEven path KThreads == kKPerBlock / KPerThread, so
// KRepeat evaluates to 1, while ElemPerThread / kMaxVecLoad is even by
// definition of KThreadEven. The assertion below therefore cannot hold on
// that path; only the odd path can satisfy it. Confirm whether the even
// path is meant to be supported.
static_assert(KRepeat == ElemPerThread/kMaxVecLoad);
|
||||
constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps);
|
||||
|
||||
// First H sequence factors the N extent, second H sequence factors the K extent.
// NOTE(review): the remaining arguments are presumably the P->RH and Y->RH
// major/minor mappings of ck_tile's tile_distribution_encoding -- confirm
// against its declaration.
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
|
||||
sequence<KRepeat, KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
|
||||
// Returns a static tile distribution built from hard-coded factors; per its
// name this is the distribution for the V tile in DRAM.
//
// The active code does not read anything from Problem: the first H sequence is
// <3, 4, 4> (product 48) and the second is <1, 16, 2> (product 32).
// The Problem-derived computation further down is compiled out with `#if 0`.
// NOTE(review): the fixed factors presumably target the hdim-48 configuration
// only -- confirm before using this policy with any other tile shape.
// NOTE(review): this function is CK_TILE_DEVICE, whereas
// MakeKDramTileDistribution and MakeShuffledVRegBlockDescriptor are
// CK_TILE_HOST_DEVICE -- confirm that the difference is intended.
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
|
||||
{
|
||||
// constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
// constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
// constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
// NOTE(review): the trailing "Y 4 / P 2 / Y 2" annotations do not match the
// values assigned to N1 and N0 -- they may be stale; confirm.
constexpr index_t N2 = 4; // Y 4
|
||||
constexpr index_t N1 = 4; // P 2
|
||||
constexpr index_t N0 = 3; // Y 2
|
||||
|
||||
constexpr index_t K2 = 2; // Y
|
||||
constexpr index_t K1 = 16; // P
|
||||
constexpr index_t K0 = 1; // P
|
||||
|
||||
// Same arguments as the encoding in MakeShuffledVRegBlockDescriptor except for
// the fifth one (sequence<1, 2, 1> here, sequence<1, 1, 2> there).
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 1>>,
|
||||
sequence<1, 2, 1>,
|
||||
sequence<0, 2, 2>>{});
|
||||
// Disabled generic path: derives the factors from Problem (block size, kN1, kK1,
// GetAlignmentV, GetSmemKPackV) instead of using the constants above.
#if 0
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>(); // 8
|
||||
constexpr index_t N0 = kNPerBlock / N1; // P // 2
|
||||
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside single wave
|
||||
// Case 1: K2 * N0 divides the warp size.
if constexpr(get_warp_size() % (K2 * N0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * N0);
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
static_assert(kKPerBlock == K0 * K1 * K2 * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
// Case 2: otherwise K2 is split into K1 * K2_m.
else
|
||||
{
|
||||
constexpr index_t K1 = (K2 * N0) / get_warp_size();
|
||||
constexpr index_t K2_m = K2 / K1;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
|
||||
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
|
||||
tuple<sequence<2, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// Returns a static tile distribution built from hard-coded factors. Elsewhere
// in this change it is passed to make_static_distributed_tensor<VDataType>(...)
// to create the destination tensor of shuffle_tile(v_shuffle_tmp, v_prefetch).
//
// Nothing is read from Problem: the first H sequence is <3, 4, 4> (product 48)
// and the second is <1, 16, 2> (product 32).
// NOTE(review): the fixed factors presumably target the hdim-48 configuration
// only -- confirm before using this policy with any other tile shape.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegBlockDescriptor()
|
||||
{
|
||||
// NOTE(review): the trailing "Y 4 / P 2 / Y 2" annotations do not match the
// values assigned to N1 and N0 -- they may be stale; confirm.
constexpr index_t N2 = 4; // Y 4
|
||||
constexpr index_t N1 = 4; // P 2
|
||||
constexpr index_t N0 = 3; // Y 2
|
||||
|
||||
constexpr index_t K2 = 2; // Y
|
||||
constexpr index_t K1 = 16; // P
|
||||
constexpr index_t K0 = 1; // P
|
||||
|
||||
// Same arguments as the encoding in MakeVDramTileDistribution except for the
// fifth one (sequence<1, 1, 2> here, sequence<1, 2, 1> there).
// NOTE(review): that argument is presumably the Y->RH major mapping, so the two
// distributions would differ only in the order of their last two Y
// dimensions -- confirm against tile_distribution_encoding.
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 1>>,
|
||||
sequence<1, 1, 2>,
|
||||
sequence<0, 2, 2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -73,6 +73,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
|
||||
|
||||
constexpr auto warp_gemm = []() {
|
||||
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
|
||||
constexpr index_t WarpGemmK = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{});
|
||||
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
|
||||
|
||||
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
|
||||
@@ -81,8 +82,12 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
|
||||
{
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16){
|
||||
if constexpr(WarpGemmK == 32)
|
||||
return WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{};
|
||||
else
|
||||
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
|
||||
}
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaF16F16F32M4N64K16{};
|
||||
}
|
||||
@@ -92,8 +97,12 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
|
||||
{
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16){
|
||||
if constexpr(WarpGemmK == 32)
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{};
|
||||
else
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
|
||||
}
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
|
||||
}
|
||||
@@ -239,7 +248,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
|
||||
return WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{};
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaF16F16F32M4N64K16{};
|
||||
}
|
||||
@@ -250,7 +259,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{};
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ namespace ck_tile {
|
||||
|
||||
static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index_t len)
|
||||
{
|
||||
if(len == 48)
|
||||
return 48;
|
||||
if(len == 96)
|
||||
return 128;
|
||||
if(len == 160)
|
||||
|
||||
Reference in New Issue
Block a user