[CK_TILE] FMHA bwd Support hdim as a Multiple of 32 (#2130)

* Fix shuffle_tile

* Add fmha bwd d160

* CHANGELOG

* Use static_cast

* Update

---------

Co-authored-by: asleepzzz <hanwen.chang@amd.com>
This commit is contained in:
Yi DING
2025-07-29 09:31:14 +08:00
committed by GitHub
parent 7fe50dc3da
commit 1926cd0cb8
5 changed files with 446 additions and 141 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -182,7 +182,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
auto k_lds_read_window =
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK0>{}),
make_tuple(number<kN0>{}, number<kQKHeaddim>{}),
k_lds_write_window.get_window_origin(),
Policy::template MakeKRegBlockDescriptor<Problem>());
@@ -208,7 +208,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
auto v_lds_read_window =
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK2>{}),
make_tuple(number<kN0>{}, number<kVHeaddim>{}),
v_lds_write_window.get_window_origin(),
Policy::template MakeVRegBlockDescriptor<Problem>());

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -22,6 +22,13 @@ namespace ck_tile {
struct BlockFmhaBwdPipelineDefaultPolicy
{
template <index_t ndim>
static constexpr auto swap_last2 = generate_sequence_v2(
[](auto i) {
return number < i == ndim - 2 ? ndim - 1 : i == ndim - 1 ? ndim - 2 : i > {};
},
number<ndim>{});
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
@@ -384,13 +391,40 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t N0 = kBlockSize / get_warp_size();
constexpr index_t N2 = kNPerBlock / (N1 * N0);
return make_static_tile_distribution(
constexpr auto dstr = make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
kNPerBlock * kKPerBlock)
{
return dstr;
}
else
{
constexpr index_t kKPerIter = 32;
static_assert(kKPerBlock % kKPerIter == 0);
constexpr index_t K0_m = kKPerBlock / kKPerIter;
constexpr index_t K2 = 2;
constexpr index_t K1_m = kKPerIter / K2;
constexpr index_t N1_m = get_warp_size() / K1_m;
constexpr index_t N2_m = kNPerBlock / (N1_m * N0);
constexpr auto dstr_m = make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<N0, N1_m, N2_m>, sequence<K0_m, K1_m, K2>>,
tuple<sequence<1>, sequence<1, 2>>, // N0, N1 K1
tuple<sequence<0>, sequence<1, 1>>,
sequence<2, 1, 2>, // K0 N2 K2
sequence<0, 2, 2>>{});
static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
kNPerBlock * kKPerBlock);
return dstr_m;
}
}
template <typename Problem>
@@ -407,13 +441,39 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
constexpr auto dstr = make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<1, 2>>, // N1, N2 K0
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 2>, // N0 K1
sequence<0, 1>>{});
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
kNPerBlock * kKPerBlock)
{
return dstr;
}
else
{
constexpr index_t kKPerIter = 32;
static_assert(kKPerBlock % kKPerIter == 0);
constexpr index_t K0_m = kKPerBlock / kKPerIter;
constexpr index_t K2 = 2;
constexpr index_t K1_m = kKPerIter / K2;
constexpr index_t N2_m = get_warp_size() / K1_m;
constexpr index_t N0_m = kNPerBlock / (N2_m * N1);
constexpr auto dstr_m = make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<N0_m, N1, N2_m>, sequence<K0_m, K1_m, K2>>,
tuple<sequence<1>, sequence<1, 2>>, // N1, N2 K1
tuple<sequence<1>, sequence<2, 1>>,
sequence<2, 1, 2>, // K0 N0 K2
sequence<0, 0, 2>>{});
static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
kNPerBlock * kKPerBlock);
return dstr_m;
}
}
template <typename Problem>
@@ -430,13 +490,41 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t M0 = kBlockSize / get_warp_size();
constexpr index_t M2 = kMPerBlock / (M1 * M0);
return make_static_tile_distribution(
constexpr auto dstr = make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
kMPerBlock * kKPerBlock)
{
return dstr;
}
else
{
// something not divisible, try a more flexible distribution
constexpr index_t kKPerIter = 32;
static_assert(kKPerBlock % kKPerIter == 0);
constexpr index_t K0_m = kKPerBlock / kKPerIter;
constexpr index_t K2 = 2;
constexpr index_t K1_m = kKPerIter / K2;
constexpr index_t M1_m = get_warp_size() / K1_m;
constexpr index_t M2_m = kMPerBlock / (M1_m * M0);
constexpr auto dstr_m = make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<M0, M1_m, M2_m>, sequence<K0_m, K1_m, K2>>,
tuple<sequence<1>, sequence<1, 2>>, // M0, M1 K1
tuple<sequence<0>, sequence<1, 1>>,
sequence<2, 1, 2>, // K0 M2 K2
sequence<0, 2, 2>>{});
static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
kMPerBlock * kKPerBlock);
return dstr_m;
}
}
template <typename Problem>
@@ -453,13 +541,41 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t M0 = kBlockSize / get_warp_size();
constexpr index_t M2 = kMPerBlock / (M1 * M0);
return make_static_tile_distribution(
constexpr auto dstr = make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
kMPerBlock * kKPerBlock)
{
return dstr;
}
else
{
// something not divisible, try a more flexible distribution
constexpr index_t kKPerIter = 32;
static_assert(kKPerBlock % kKPerIter == 0);
constexpr index_t K0_m = kKPerBlock / kKPerIter;
constexpr index_t K2 = 2;
constexpr index_t K1_m = kKPerIter / K2;
constexpr index_t M1_m = get_warp_size() / K1_m;
constexpr index_t M2_m = kMPerBlock / (M1_m * M0);
constexpr auto dstr_m = make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<M0, M1_m, M2_m>, sequence<K0_m, K1_m, K2>>,
tuple<sequence<1>, sequence<1, 2>>, // M0, M1 K1
tuple<sequence<0>, sequence<1, 1>>,
sequence<2, 1, 2>, // K0 M2 K2
sequence<0, 2, 2>>{});
static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
kMPerBlock * kKPerBlock);
return dstr_m;
}
}
template <typename Problem, typename BlockGemm>
@@ -504,13 +620,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t M0 = kBlockSize / get_warp_size();
constexpr index_t M2 = kMPerBlock / (M1 * M0);
return make_static_tile_distribution(
constexpr auto dstr = make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
kMPerBlock * kNPerBlock);
return dstr;
}
template <typename DataType, index_t MPerBlock, index_t KPerBlock>
@@ -522,13 +641,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t M1 = get_warp_size();
constexpr index_t M0 = MPerBlock / M1;
return make_static_tile_distribution(
constexpr auto dstr = make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1>>,
tuple<sequence<0>, sequence<1>>,
sequence<1, 2, 2>,
sequence<2, 0, 1>>{});
static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
MPerBlock * KPerBlock);
return dstr;
}
template <typename Problem>
@@ -569,13 +691,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M0 = kMPerBlock / (M1 * M2);
return make_static_tile_distribution(
constexpr auto dstr = make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<1>, sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<2>, sequence<2, 3>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2, 3>,
sequence<0, 0, 1>>{});
static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
kMPerBlock * kKPerBlock);
return dstr;
}
template <typename Problem>
@@ -594,13 +719,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M0 = kMPerBlock / (M1 * M2);
return make_static_tile_distribution(
constexpr auto dstr = make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
kMPerBlock * kKPerBlock);
return dstr;
}
// these are for lds
@@ -666,56 +795,80 @@ struct BlockFmhaBwdPipelineDefaultPolicy
return 16 / sizeof(GemmDataType);
}
template <index_t MNPerBlock, index_t KPerBlock, index_t KPack>
template <index_t KIter, index_t MNPerBlock, index_t KPerSubBlock, index_t KPack>
CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor()
{
constexpr auto DataTypeSize = 2; // sizeof(F16/BF16)
constexpr auto MNLdsLayer =
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
(32 * 4 / KPerSubBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerSubBlock / DataTypeSize);
constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack * MNLdsLayer>{},
number<MNPerBlock / MNLdsLayer>{},
number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock * MNLdsLayer>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto x_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<KIter>{},
number<KPerSubBlock / KPack * MNLdsLayer>{},
number<MNPerBlock / MNLdsLayer>{},
number<KPack>{}),
make_tuple(number<KPerSubBlock * MNPerBlock>{},
number<KPack>{},
number<KPerSubBlock * MNLdsLayer>{},
number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto x_lds_block_desc_permuted = transform_tensor_descriptor(
x_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<MNPerBlock / MNLdsLayer>{},
number<KPerBlock / KPack * MNLdsLayer>{})),
make_tuple(make_pass_through_transform(number<KIter>{}),
make_xor_transform(make_tuple(number<MNPerBlock / MNLdsLayer>{},
number<KPerSubBlock / KPack * MNLdsLayer>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
make_tuple(sequence<0>{}, sequence<2, 1>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<2, 1>{}, sequence<3>{}));
constexpr auto x_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
x_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(
make_tuple(number<KPerBlock / KPack>{}, number<MNLdsLayer>{})),
make_tuple(make_pass_through_transform(number<KIter>{}),
make_unmerge_transform(
make_tuple(number<KPerSubBlock / KPack>{}, number<MNLdsLayer>{})),
make_pass_through_transform(number<MNPerBlock / MNLdsLayer>{}),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}));
constexpr auto x_lds_block_desc = transform_tensor_descriptor(
x_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<MNPerBlock / MNLdsLayer>{}, number<MNLdsLayer>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
make_merge_transform_v3_division_mod(make_tuple(
number<KIter>{}, number<KPerSubBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<2, 3>{}, sequence<0, 1, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
static_assert(container_reduce(x_lds_block_desc.get_lengths(),
std::multiplies<index_t>{},
1) == KIter * MNPerBlock * KPerSubBlock);
return x_lds_block_desc;
}
// Convenience overload for the un-split case: forwards to the four-parameter
// MakeXLdsBlockDescriptor with KIter fixed to 1, so KPerBlock is passed through
// as the KPerSubBlock argument.
template <index_t MNPerBlock, index_t KPerBlock, index_t KPack>
CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor()
{
    constexpr index_t kSingleKIter = 1;
    return MakeXLdsBlockDescriptor<kSingleKIter, MNPerBlock, KPerBlock, KPack>();
}
// Convenience overload for the un-split case: forwards to the six-parameter
// MakeXTLdsBlockDescriptor with MNIter fixed to 1, so MNPerBlock is passed
// through as the MNPerSubBlock argument.
template <typename Problem,
          index_t MNPerBlock,
          index_t KPerBlock,
          index_t KPack,
          index_t KPackT>
CK_TILE_HOST_DEVICE static constexpr auto MakeXTLdsBlockDescriptor()
{
    constexpr index_t kSingleMNIter = 1;
    return MakeXTLdsBlockDescriptor<Problem,
                                    kSingleMNIter,
                                    MNPerBlock,
                                    KPerBlock,
                                    KPack,
                                    KPackT>();
}
template <typename Problem,
index_t MNIter,
index_t MNPerSubBlock,
index_t KPerBlock,
index_t KPack,
index_t KPackT>
CK_TILE_HOST_DEVICE static constexpr auto MakeXTLdsBlockDescriptor()
{
// The kfold and mnpair dimensions are not always required.
// More dimensions in merge_transform make it harder to generate immarg offsets.
@@ -723,7 +876,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr auto MNPerXDL = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
constexpr auto kBlockSize = Problem::kBlockSize;
constexpr auto MN0 = MNPerBlock / KPack;
constexpr auto MN0 = MNPerSubBlock / KPack;
constexpr auto MN1 = KPack;
constexpr auto KThreadWrite = kBlockSize / MN0;
@@ -745,13 +898,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
: ((128 / (KPackT * MNPerXDL * 2)) > MN0 ? MN0 : 128 / (KPackT * MNPerXDL * 2));
constexpr auto xt_lds_block_desc_raw = make_naive_tensor_descriptor(
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
make_tuple(number<MNIter>{},
number<KThreadWrite / kfold / KThreadReadPerm>{},
number<K0PerThreadWrite>{},
number<KThreadReadPerm * MN1>{},
number<kfold * MN0 / mnpair>{},
number<mnpair>{},
KPackT),
make_tuple(number<KPackT * kfold * MN0 * KThreadReadPerm * MN1 * K0PerThreadWrite>{},
make_tuple(number<KPackT * MN0 * KThreadWrite * MN1 * K0PerThreadWrite>{},
number<KPackT * kfold * MN0 * KThreadReadPerm * MN1 * K0PerThreadWrite>{},
number<KPackT * kfold * MN0 * KThreadReadPerm * MN1>{},
number<KPackT * kfold * MN0>{},
number<KPackT * mnpair>{},
@@ -763,20 +918,30 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr auto xt_lds_block_desc_permuted = transform_tensor_descriptor(
xt_lds_block_desc_raw,
make_tuple(
make_pass_through_transform(number<MNIter>{}),
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_xor_transform(
make_tuple(number<KThreadReadPerm * MN1>{}, number<kfold * MN0 / mnpair>{})),
make_pass_through_transform(number<mnpair>{}),
make_pass_through_transform(KPackT)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3, 4>{},
sequence<5>{},
sequence<6>{}),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3, 4>{},
sequence<5>{},
sequence<6>{}));
constexpr auto xt_lds_block_desc_unmerged = transform_tensor_descriptor(
xt_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(number<MNIter>{}),
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<MN1>{})),
@@ -788,27 +953,32 @@ struct BlockFmhaBwdPipelineDefaultPolicy
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<1>{},
sequence<5>{},
sequence<6>{}),
make_tuple(sequence<0>{},
sequence<2>{},
sequence<0, 3>{},
sequence<4, 5>{},
sequence<6>{},
sequence<7>{}));
sequence<3>{},
sequence<1, 4>{},
sequence<5, 6>{},
sequence<7>{},
sequence<8>{}));
constexpr auto xt_lds_block_desc = transform_tensor_descriptor(
xt_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<KThreadReadPerm>{},
number<KThreadWrite / kfold / KThreadReadPerm>{},
number<kfold>{},
number<K0PerThreadWrite>{},
number<KPackT>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<MN0 / mnpair>{}, number<mnpair>{}, number<MN1>{}))),
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
make_tuple(
make_merge_transform_v3_division_mod(
make_tuple(number<KThreadReadPerm>{},
number<KThreadWrite / kfold / KThreadReadPerm>{},
number<kfold>{},
number<K0PerThreadWrite>{},
number<KPackT>{})),
make_merge_transform_v3_division_mod(make_tuple(
number<MNIter>{}, number<MN0 / mnpair>{}, number<mnpair>{}, number<MN1>{}))),
make_tuple(sequence<1, 2, 5, 3, 8>{}, sequence<0, 6, 7, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
static_assert(container_reduce(xt_lds_block_desc.get_lengths(),
std::multiplies<index_t>{},
1) == MNPerSubBlock * MNIter * KPerBlock);
return xt_lds_block_desc;
}
@@ -817,9 +987,24 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>();
using dram_encoding = typename decltype(MakeKDramTileDistribution<Problem>())::DstrEncode;
constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
if constexpr(dram_y_ndim == 2)
{
constexpr index_t kKPack = GetSmemKPackK<Problem>();
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>();
}
else if constexpr(dram_y_ndim == 3)
{
constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
return MakeXLdsBlockDescriptor<KIter, kNPerBlock, kKPerBlock / KIter, kKPack>();
}
else
{
static_assert(false, "Unexpected dram y dimension");
}
}
template <typename Problem>
@@ -850,7 +1035,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
static_assert(container_reduce(k_block_dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
kNPerBlock * kKPerBlock);
return k_block_dstr;
}
@@ -860,9 +1046,23 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kVPack = GetSmemKPackV<Problem>();
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kVPack>();
using dram_encoding = typename decltype(MakeVDramTileDistribution<Problem>())::DstrEncode;
constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
if constexpr(dram_y_ndim == 2)
{
constexpr index_t kVPack = GetSmemKPackV<Problem>();
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kVPack>();
}
else if constexpr(dram_y_ndim == 3)
{
constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
constexpr index_t kVPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
return MakeXLdsBlockDescriptor<KIter, kNPerBlock, kKPerBlock / KIter, kVPack>();
}
else
{
static_assert(false, "Unexpected dram y dimension");
}
}
template <typename Problem>
@@ -893,30 +1093,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
static_assert(container_reduce(v_block_dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
kNPerBlock * kKPerBlock);
return v_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKRegWriteBlockDescriptor()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t K1 = GetAlignmentK<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = GetTransposedAlignmentK<Problem>();
constexpr index_t N1 = get_warp_size() / K0;
constexpr index_t N0 = kBlockSize / get_warp_size();
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 1>,
sequence<1, 2>>{});
using dram_encoding = typename decltype(MakeKDramTileDistribution<Problem>())::DstrEncode;
constexpr index_t y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
static_assert(y_ndim >= 2);
using shuffled_encoding_t =
tile_distribution_encoding_shuffle_t<dram_encoding,
remove_cvref_t<decltype(swap_last2<y_ndim>)>>;
return make_static_tile_distribution(shuffled_encoding_t{});
}
template <typename Problem>
@@ -926,10 +1117,30 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr index_t kKPackT = GetSmemKPackKT<Problem>();
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
using dram_encoding = typename decltype(MakeKDramTileDistribution<Problem>())::DstrEncode;
constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
if constexpr(dram_y_ndim == 2)
{
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr index_t kKPackT = GetSmemKPackKT<Problem>();
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
}
else if constexpr(dram_y_ndim == 3)
{
constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
constexpr index_t kKPackT = typename dram_encoding::HsLengthss{}.at(number<0>{}).at(2);
return MakeXTLdsBlockDescriptor<Problem,
KIter,
kNPerBlock / KIter,
kKPerBlock,
kKPack,
kKPackT>();
}
else
{
static_assert(false, "Unexpected dram y dimension");
}
}
template <typename Problem>
@@ -976,7 +1187,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
kt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto kt_block_dstr = make_static_tile_distribution(kt_block_dstr_encode);
static_assert(container_reduce(kt_block_dstr.get_lengths(),
std::multiplies<index_t>{},
1) == kNPerBlock * kKPerBlock);
return kt_block_dstr;
}
@@ -986,9 +1199,23 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
using dram_encoding = typename decltype(MakeQDramTileDistribution<Problem>())::DstrEncode;
constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
if constexpr(dram_y_ndim == 2)
{
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
}
else if constexpr(dram_y_ndim == 3)
{
constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
return MakeXLdsBlockDescriptor<KIter, kMPerBlock, kKPerBlock / KIter, kKPack>();
}
else
{
static_assert(false, "Unexpected dram y dimension");
}
}
template <typename Problem>
@@ -1019,30 +1246,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy
q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode);
static_assert(container_reduce(q_block_dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
kMPerBlock * kKPerBlock);
return q_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQRegWriteBlockDescriptor()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t K1 = GetAlignmentQ<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = GetTransposedAlignmentQ<Problem>();
constexpr index_t N1 = get_warp_size() / K0;
constexpr index_t N0 = kBlockSize / get_warp_size();
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 1>,
sequence<1, 2>>{});
using dram_encoding = typename decltype(MakeQDramTileDistribution<Problem>())::DstrEncode;
constexpr index_t y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
static_assert(y_ndim >= 2);
using shuffled_encoding_t =
tile_distribution_encoding_shuffle_t<dram_encoding,
remove_cvref_t<decltype(swap_last2<y_ndim>)>>;
return make_static_tile_distribution(shuffled_encoding_t{});
}
template <typename Problem>
@@ -1052,10 +1270,30 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
constexpr index_t kKPackT = GetSmemKPackQT<Problem>();
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
using dram_encoding = typename decltype(MakeQDramTileDistribution<Problem>())::DstrEncode;
constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
if constexpr(dram_y_ndim == 2)
{
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
constexpr index_t kKPackT = GetSmemKPackQT<Problem>();
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
}
else if constexpr(dram_y_ndim == 3)
{
constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
constexpr index_t kKPackT = typename dram_encoding::HsLengthss{}.at(number<0>{}).at(2);
return MakeXTLdsBlockDescriptor<Problem,
KIter,
kNPerBlock / KIter,
kKPerBlock,
kKPack,
kKPackT>();
}
else
{
static_assert(false, "Unexpected dram y dimension");
}
}
template <typename Problem>
@@ -1103,6 +1341,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
qt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto qt_block_dstr = make_static_tile_distribution(qt_block_dstr_encode);
static_assert(container_reduce(qt_block_dstr.get_lengths(),
std::multiplies<index_t>{},
1) == kNPerBlock * kKPerBlock);
return qt_block_dstr;
}
@@ -1135,7 +1376,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
dst_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto dst_block_dstr = make_static_tile_distribution(dst_block_dstr_encode);
static_assert(container_reduce(dst_block_dstr.get_lengths(),
std::multiplies<index_t>{},
1) == kMPerBlock * kKPerBlock);
return dst_block_dstr;
}
@@ -1177,13 +1420,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t M1 = MWarp;
constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM);
return make_static_tile_distribution(
constexpr auto dstr = make_static_tile_distribution(
tile_distribution_encoding<sequence<N0, N1>,
tuple<sequence<M0, M1, M2, M3, M4>>,
tuple<sequence<1, 0>, sequence<1, 0>>,
tuple<sequence<1, 0>, sequence<3, 1>>,
sequence<1, 1, 1>,
sequence<0, 2, 4>>{});
static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
kMPerBlock);
return dstr;
}
template <typename Problem>
@@ -1193,9 +1439,24 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
using dram_encoding =
typename decltype(MakeOGradDramTileDistribution<Problem>())::DstrEncode;
constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
if constexpr(dram_y_ndim == 2)
{
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
}
else if constexpr(dram_y_ndim == 3)
{
constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
return MakeXLdsBlockDescriptor<KIter, kMPerBlock, kKPerBlock / KIter, kKPack>();
}
else
{
static_assert(false, "Unexpected dram y dimension");
}
}
template <typename Problem>
@@ -1226,30 +1487,24 @@ struct BlockFmhaBwdPipelineDefaultPolicy
do_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto do_block_dstr = make_static_tile_distribution(do_block_dstr_encode);
static_assert(container_reduce(do_block_dstr.get_lengths(),
std::multiplies<index_t>{},
1) == kMPerBlock * kKPerBlock);
return do_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradRegWriteBlockDescriptor()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t K1 = GetAlignmentOGrad<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = GetTransposedAlignmentOGrad<Problem>();
constexpr index_t N1 = get_warp_size() / K0;
constexpr index_t N0 = kBlockSize / get_warp_size();
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 1>,
sequence<1, 2>>{});
using dram_encoding =
typename decltype(MakeOGradDramTileDistribution<Problem>())::DstrEncode;
constexpr index_t y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
static_assert(y_ndim >= 2);
using shuffled_encoding_t =
tile_distribution_encoding_shuffle_t<dram_encoding,
remove_cvref_t<decltype(swap_last2<y_ndim>)>>;
return make_static_tile_distribution(shuffled_encoding_t{});
}
template <typename Problem>
@@ -1259,10 +1514,31 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
constexpr index_t kKPackT = GetSmemKPackOGradT<Problem>();
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
using dram_encoding =
typename decltype(MakeOGradDramTileDistribution<Problem>())::DstrEncode;
constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
if constexpr(dram_y_ndim == 2)
{
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
constexpr index_t kKPackT = GetSmemKPackOGradT<Problem>();
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
}
else if constexpr(dram_y_ndim == 3)
{
constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
constexpr index_t kKPackT = typename dram_encoding::HsLengthss{}.at(number<0>{}).at(2);
return MakeXTLdsBlockDescriptor<Problem,
KIter,
kNPerBlock / KIter,
kKPerBlock,
kKPack,
kKPackT>();
}
else
{
static_assert(false, "Unexpected dram y dimension");
}
}
template <typename Problem>
@@ -1310,7 +1586,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
dot_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto dot_block_dstr = make_static_tile_distribution(dot_block_dstr_encode);
static_assert(container_reduce(dot_block_dstr.get_lengths(),
std::multiplies<index_t>{},
1) == kNPerBlock * kKPerBlock);
return dot_block_dstr;
}
@@ -1342,7 +1620,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
pt_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto pt_block_dstr = make_static_tile_distribution(pt_block_dstr_encode);
static_assert(container_reduce(pt_block_dstr.get_lengths(),
std::multiplies<index_t>{},
1) == kMPerBlock * kKPerBlock);
return pt_block_dstr;
}
@@ -1384,7 +1664,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
ds_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto ds_block_dstr = make_static_tile_distribution(ds_block_dstr_encode);
static_assert(container_reduce(ds_block_dstr.get_lengths(),
std::multiplies<index_t>{},
1) == kMPerBlock * kKPerBlock);
return ds_block_dstr;
}