mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
[CK_TILE] FMHA bwd Support hdim as a Multiple of 32 (#2130)
* Fix shuffle_tile * Add fmha bwd d160 * CHANGELOG * Use static_cast * Update --------- Co-authored-by: asleepzzz <hanwen.chang@amd.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -182,7 +182,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0>{}, number<kK0>{}),
|
||||
make_tuple(number<kN0>{}, number<kQKHeaddim>{}),
|
||||
k_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeKRegBlockDescriptor<Problem>());
|
||||
|
||||
@@ -208,7 +208,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
|
||||
auto v_lds_read_window =
|
||||
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0>{}, number<kK2>{}),
|
||||
make_tuple(number<kN0>{}, number<kVHeaddim>{}),
|
||||
v_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeVRegBlockDescriptor<Problem>());
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -22,6 +22,13 @@ namespace ck_tile {
|
||||
|
||||
struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
template <index_t ndim>
|
||||
static constexpr auto swap_last2 = generate_sequence_v2(
|
||||
[](auto i) {
|
||||
return number < i == ndim - 2 ? ndim - 1 : i == ndim - 1 ? ndim - 2 : i > {};
|
||||
},
|
||||
number<ndim>{});
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
|
||||
{
|
||||
@@ -384,13 +391,40 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t N0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N2 = kNPerBlock / (N1 * N0);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
constexpr auto dstr = make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
|
||||
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kNPerBlock * kKPerBlock)
|
||||
{
|
||||
return dstr;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t kKPerIter = 32;
|
||||
static_assert(kKPerBlock % kKPerIter == 0);
|
||||
constexpr index_t K0_m = kKPerBlock / kKPerIter;
|
||||
constexpr index_t K2 = 2;
|
||||
constexpr index_t K1_m = kKPerIter / K2;
|
||||
constexpr index_t N1_m = get_warp_size() / K1_m;
|
||||
constexpr index_t N2_m = kNPerBlock / (N1_m * N0);
|
||||
constexpr auto dstr_m = make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<N0, N1_m, N2_m>, sequence<K0_m, K1_m, K2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>, // N0, N1 K1
|
||||
tuple<sequence<0>, sequence<1, 1>>,
|
||||
sequence<2, 1, 2>, // K0 N2 K2
|
||||
sequence<0, 2, 2>>{});
|
||||
static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kNPerBlock * kKPerBlock);
|
||||
return dstr_m;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -407,13 +441,39 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
constexpr auto dstr = make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>, // N1, N2 K0
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 2>, // N0 K1
|
||||
sequence<0, 1>>{});
|
||||
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kNPerBlock * kKPerBlock)
|
||||
{
|
||||
return dstr;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t kKPerIter = 32;
|
||||
static_assert(kKPerBlock % kKPerIter == 0);
|
||||
constexpr index_t K0_m = kKPerBlock / kKPerIter;
|
||||
constexpr index_t K2 = 2;
|
||||
constexpr index_t K1_m = kKPerIter / K2;
|
||||
constexpr index_t N2_m = get_warp_size() / K1_m;
|
||||
constexpr index_t N0_m = kNPerBlock / (N2_m * N1);
|
||||
constexpr auto dstr_m = make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<N0_m, N1, N2_m>, sequence<K0_m, K1_m, K2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>, // N1, N2 K1
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
sequence<2, 1, 2>, // K0 N0 K2
|
||||
sequence<0, 0, 2>>{});
|
||||
static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kNPerBlock * kKPerBlock);
|
||||
return dstr_m;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -430,13 +490,41 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t M0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M2 = kMPerBlock / (M1 * M0);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
constexpr auto dstr = make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
|
||||
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kMPerBlock * kKPerBlock)
|
||||
{
|
||||
return dstr;
|
||||
}
|
||||
else
|
||||
{
|
||||
// something not divisible, try a more flexible distribution
|
||||
constexpr index_t kKPerIter = 32;
|
||||
static_assert(kKPerBlock % kKPerIter == 0);
|
||||
constexpr index_t K0_m = kKPerBlock / kKPerIter;
|
||||
constexpr index_t K2 = 2;
|
||||
constexpr index_t K1_m = kKPerIter / K2;
|
||||
constexpr index_t M1_m = get_warp_size() / K1_m;
|
||||
constexpr index_t M2_m = kMPerBlock / (M1_m * M0);
|
||||
constexpr auto dstr_m = make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<M0, M1_m, M2_m>, sequence<K0_m, K1_m, K2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>, // M0, M1 K1
|
||||
tuple<sequence<0>, sequence<1, 1>>,
|
||||
sequence<2, 1, 2>, // K0 M2 K2
|
||||
sequence<0, 2, 2>>{});
|
||||
static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kMPerBlock * kKPerBlock);
|
||||
return dstr_m;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -453,13 +541,41 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t M0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M2 = kMPerBlock / (M1 * M0);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
constexpr auto dstr = make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
|
||||
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kMPerBlock * kKPerBlock)
|
||||
{
|
||||
return dstr;
|
||||
}
|
||||
else
|
||||
{
|
||||
// something not divisible, try a more flexible distribution
|
||||
constexpr index_t kKPerIter = 32;
|
||||
static_assert(kKPerBlock % kKPerIter == 0);
|
||||
constexpr index_t K0_m = kKPerBlock / kKPerIter;
|
||||
constexpr index_t K2 = 2;
|
||||
constexpr index_t K1_m = kKPerIter / K2;
|
||||
constexpr index_t M1_m = get_warp_size() / K1_m;
|
||||
constexpr index_t M2_m = kMPerBlock / (M1_m * M0);
|
||||
constexpr auto dstr_m = make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<M0, M1_m, M2_m>, sequence<K0_m, K1_m, K2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>, // M0, M1 K1
|
||||
tuple<sequence<0>, sequence<1, 1>>,
|
||||
sequence<2, 1, 2>, // K0 M2 K2
|
||||
sequence<0, 2, 2>>{});
|
||||
static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kMPerBlock * kKPerBlock);
|
||||
return dstr_m;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem, typename BlockGemm>
|
||||
@@ -504,13 +620,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t M0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M2 = kMPerBlock / (M1 * M0);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
constexpr auto dstr = make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kMPerBlock * kNPerBlock);
|
||||
return dstr;
|
||||
}
|
||||
|
||||
template <typename DataType, index_t MPerBlock, index_t KPerBlock>
|
||||
@@ -522,13 +641,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t M1 = get_warp_size();
|
||||
constexpr index_t M0 = MPerBlock / M1;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
constexpr auto dstr = make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1>>,
|
||||
tuple<sequence<0>, sequence<1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<2, 0, 1>>{});
|
||||
static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
MPerBlock * KPerBlock);
|
||||
return dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -569,13 +691,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M0 = kMPerBlock / (M1 * M2);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
constexpr auto dstr = make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<1>, sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<2>, sequence<2, 3>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2, 3>,
|
||||
sequence<0, 0, 1>>{});
|
||||
static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kMPerBlock * kKPerBlock);
|
||||
return dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -594,13 +719,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M0 = kMPerBlock / (M1 * M2);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
constexpr auto dstr = make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
|
||||
static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kMPerBlock * kKPerBlock);
|
||||
return dstr;
|
||||
}
|
||||
|
||||
// these are for lds
|
||||
@@ -666,56 +795,80 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
return 16 / sizeof(GemmDataType);
|
||||
}
|
||||
|
||||
template <index_t MNPerBlock, index_t KPerBlock, index_t KPack>
|
||||
template <index_t KIter, index_t MNPerBlock, index_t KPerSubBlock, index_t KPack>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor()
|
||||
{
|
||||
constexpr auto DataTypeSize = 2; // sizeof(F16/BF16)
|
||||
constexpr auto MNLdsLayer =
|
||||
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
|
||||
(32 * 4 / KPerSubBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerSubBlock / DataTypeSize);
|
||||
|
||||
constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack * MNLdsLayer>{},
|
||||
number<MNPerBlock / MNLdsLayer>{},
|
||||
number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock * MNLdsLayer>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
constexpr auto x_lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<KIter>{},
|
||||
number<KPerSubBlock / KPack * MNLdsLayer>{},
|
||||
number<MNPerBlock / MNLdsLayer>{},
|
||||
number<KPack>{}),
|
||||
make_tuple(number<KPerSubBlock * MNPerBlock>{},
|
||||
number<KPack>{},
|
||||
number<KPerSubBlock * MNLdsLayer>{},
|
||||
number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto x_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
x_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<MNPerBlock / MNLdsLayer>{},
|
||||
number<KPerBlock / KPack * MNLdsLayer>{})),
|
||||
make_tuple(make_pass_through_transform(number<KIter>{}),
|
||||
make_xor_transform(make_tuple(number<MNPerBlock / MNLdsLayer>{},
|
||||
number<KPerSubBlock / KPack * MNLdsLayer>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
make_tuple(sequence<0>{}, sequence<2, 1>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<2, 1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto x_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
x_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MNLdsLayer>{})),
|
||||
make_tuple(make_pass_through_transform(number<KIter>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<KPerSubBlock / KPack>{}, number<MNLdsLayer>{})),
|
||||
make_pass_through_transform(number<MNPerBlock / MNLdsLayer>{}),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}));
|
||||
|
||||
constexpr auto x_lds_block_desc = transform_tensor_descriptor(
|
||||
x_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<MNPerBlock / MNLdsLayer>{}, number<MNLdsLayer>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<KIter>{}, number<KPerSubBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<2, 3>{}, sequence<0, 1, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
static_assert(container_reduce(x_lds_block_desc.get_lengths(),
|
||||
std::multiplies<index_t>{},
|
||||
1) == KIter * MNPerBlock * KPerSubBlock);
|
||||
return x_lds_block_desc;
|
||||
}
|
||||
|
||||
template <index_t MNPerBlock, index_t KPerBlock, index_t KPack>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor()
|
||||
{
|
||||
return MakeXLdsBlockDescriptor<1, MNPerBlock, KPerBlock, KPack>();
|
||||
}
|
||||
template <typename Problem,
|
||||
index_t MNPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t KPack,
|
||||
index_t KPackT>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeXTLdsBlockDescriptor()
|
||||
{
|
||||
return MakeXTLdsBlockDescriptor<Problem, 1, MNPerBlock, KPerBlock, KPack, KPackT>();
|
||||
}
|
||||
template <typename Problem,
|
||||
index_t MNIter,
|
||||
index_t MNPerSubBlock,
|
||||
index_t KPerBlock,
|
||||
index_t KPack,
|
||||
index_t KPackT>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeXTLdsBlockDescriptor()
|
||||
{
|
||||
// kfold and mnpair dimensions are not always required.
|
||||
// more dimensions in merge_transform increase the difficulty of generating immarg offsets
|
||||
@@ -723,7 +876,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr auto MNPerXDL = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
|
||||
constexpr auto kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr auto MN0 = MNPerBlock / KPack;
|
||||
constexpr auto MN0 = MNPerSubBlock / KPack;
|
||||
constexpr auto MN1 = KPack;
|
||||
|
||||
constexpr auto KThreadWrite = kBlockSize / MN0;
|
||||
@@ -745,13 +898,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
: ((128 / (KPackT * MNPerXDL * 2)) > MN0 ? MN0 : 128 / (KPackT * MNPerXDL * 2));
|
||||
|
||||
constexpr auto xt_lds_block_desc_raw = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
make_tuple(number<MNIter>{},
|
||||
number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
number<KThreadReadPerm * MN1>{},
|
||||
number<kfold * MN0 / mnpair>{},
|
||||
number<mnpair>{},
|
||||
KPackT),
|
||||
make_tuple(number<KPackT * kfold * MN0 * KThreadReadPerm * MN1 * K0PerThreadWrite>{},
|
||||
make_tuple(number<KPackT * MN0 * KThreadWrite * MN1 * K0PerThreadWrite>{},
|
||||
number<KPackT * kfold * MN0 * KThreadReadPerm * MN1 * K0PerThreadWrite>{},
|
||||
number<KPackT * kfold * MN0 * KThreadReadPerm * MN1>{},
|
||||
number<KPackT * kfold * MN0>{},
|
||||
number<KPackT * mnpair>{},
|
||||
@@ -763,20 +918,30 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr auto xt_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
xt_lds_block_desc_raw,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<MNIter>{}),
|
||||
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_xor_transform(
|
||||
make_tuple(number<KThreadReadPerm * MN1>{}, number<kfold * MN0 / mnpair>{})),
|
||||
make_pass_through_transform(number<mnpair>{}),
|
||||
make_pass_through_transform(KPackT)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3, 4>{},
|
||||
sequence<5>{},
|
||||
sequence<6>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3, 4>{},
|
||||
sequence<5>{},
|
||||
sequence<6>{}));
|
||||
|
||||
constexpr auto xt_lds_block_desc_unmerged = transform_tensor_descriptor(
|
||||
xt_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<MNIter>{}),
|
||||
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<MN1>{})),
|
||||
@@ -788,27 +953,32 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_tuple(sequence<1>{},
|
||||
sequence<5>{},
|
||||
sequence<6>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<2>{},
|
||||
sequence<0, 3>{},
|
||||
sequence<4, 5>{},
|
||||
sequence<6>{},
|
||||
sequence<7>{}));
|
||||
sequence<3>{},
|
||||
sequence<1, 4>{},
|
||||
sequence<5, 6>{},
|
||||
sequence<7>{},
|
||||
sequence<8>{}));
|
||||
|
||||
constexpr auto xt_lds_block_desc = transform_tensor_descriptor(
|
||||
xt_lds_block_desc_unmerged,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KThreadReadPerm>{},
|
||||
number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<kfold>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
number<KPackT>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<MN0 / mnpair>{}, number<mnpair>{}, number<MN1>{}))),
|
||||
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KThreadReadPerm>{},
|
||||
number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<kfold>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
number<KPackT>{})),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<MNIter>{}, number<MN0 / mnpair>{}, number<mnpair>{}, number<MN1>{}))),
|
||||
make_tuple(sequence<1, 2, 5, 3, 8>{}, sequence<0, 6, 7, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
static_assert(container_reduce(xt_lds_block_desc.get_lengths(),
|
||||
std::multiplies<index_t>{},
|
||||
1) == MNPerSubBlock * MNIter * KPerBlock);
|
||||
return xt_lds_block_desc;
|
||||
}
|
||||
|
||||
@@ -817,9 +987,24 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
|
||||
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>();
|
||||
using dram_encoding = typename decltype(MakeKDramTileDistribution<Problem>())::DstrEncode;
|
||||
constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
|
||||
if constexpr(dram_y_ndim == 2)
|
||||
{
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>();
|
||||
}
|
||||
else if constexpr(dram_y_ndim == 3)
|
||||
{
|
||||
constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
|
||||
constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
|
||||
return MakeXLdsBlockDescriptor<KIter, kNPerBlock, kKPerBlock / KIter, kKPack>();
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unexpected dram y dimension");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -850,7 +1035,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
|
||||
|
||||
static_assert(container_reduce(k_block_dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kNPerBlock * kKPerBlock);
|
||||
return k_block_dstr;
|
||||
}
|
||||
|
||||
@@ -860,9 +1046,23 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t kVPack = GetSmemKPackV<Problem>();
|
||||
|
||||
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kVPack>();
|
||||
using dram_encoding = typename decltype(MakeVDramTileDistribution<Problem>())::DstrEncode;
|
||||
constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
|
||||
if constexpr(dram_y_ndim == 2)
|
||||
{
|
||||
constexpr index_t kVPack = GetSmemKPackV<Problem>();
|
||||
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kVPack>();
|
||||
}
|
||||
else if constexpr(dram_y_ndim == 3)
|
||||
{
|
||||
constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
|
||||
constexpr index_t kVPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
|
||||
return MakeXLdsBlockDescriptor<KIter, kNPerBlock, kKPerBlock / KIter, kVPack>();
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unexpected dram y dimension");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -893,30 +1093,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
|
||||
|
||||
static_assert(container_reduce(v_block_dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kNPerBlock * kKPerBlock);
|
||||
return v_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKRegWriteBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentK<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = GetTransposedAlignmentK<Problem>();
|
||||
constexpr index_t N1 = get_warp_size() / K0;
|
||||
constexpr index_t N0 = kBlockSize / get_warp_size();
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>>{});
|
||||
using dram_encoding = typename decltype(MakeKDramTileDistribution<Problem>())::DstrEncode;
|
||||
constexpr index_t y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
|
||||
static_assert(y_ndim >= 2);
|
||||
using shuffled_encoding_t =
|
||||
tile_distribution_encoding_shuffle_t<dram_encoding,
|
||||
remove_cvref_t<decltype(swap_last2<y_ndim>)>>;
|
||||
return make_static_tile_distribution(shuffled_encoding_t{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -926,10 +1117,30 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
constexpr index_t kKPackT = GetSmemKPackKT<Problem>();
|
||||
|
||||
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
|
||||
using dram_encoding = typename decltype(MakeKDramTileDistribution<Problem>())::DstrEncode;
|
||||
constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
|
||||
if constexpr(dram_y_ndim == 2)
|
||||
{
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
constexpr index_t kKPackT = GetSmemKPackKT<Problem>();
|
||||
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
|
||||
}
|
||||
else if constexpr(dram_y_ndim == 3)
|
||||
{
|
||||
constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
|
||||
constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
|
||||
constexpr index_t kKPackT = typename dram_encoding::HsLengthss{}.at(number<0>{}).at(2);
|
||||
return MakeXTLdsBlockDescriptor<Problem,
|
||||
KIter,
|
||||
kNPerBlock / KIter,
|
||||
kKPerBlock,
|
||||
kKPack,
|
||||
kKPackT>();
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unexpected dram y dimension");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -976,7 +1187,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
kt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto kt_block_dstr = make_static_tile_distribution(kt_block_dstr_encode);
|
||||
|
||||
static_assert(container_reduce(kt_block_dstr.get_lengths(),
|
||||
std::multiplies<index_t>{},
|
||||
1) == kNPerBlock * kKPerBlock);
|
||||
return kt_block_dstr;
|
||||
}
|
||||
|
||||
@@ -986,9 +1199,23 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
||||
|
||||
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
|
||||
using dram_encoding = typename decltype(MakeQDramTileDistribution<Problem>())::DstrEncode;
|
||||
constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
|
||||
if constexpr(dram_y_ndim == 2)
|
||||
{
|
||||
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
||||
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
|
||||
}
|
||||
else if constexpr(dram_y_ndim == 3)
|
||||
{
|
||||
constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
|
||||
constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
|
||||
return MakeXLdsBlockDescriptor<KIter, kMPerBlock, kKPerBlock / KIter, kKPack>();
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unexpected dram y dimension");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -1019,30 +1246,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode);
|
||||
|
||||
static_assert(container_reduce(q_block_dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kMPerBlock * kKPerBlock);
|
||||
return q_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQRegWriteBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentQ<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = GetTransposedAlignmentQ<Problem>();
|
||||
constexpr index_t N1 = get_warp_size() / K0;
|
||||
constexpr index_t N0 = kBlockSize / get_warp_size();
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>>{});
|
||||
using dram_encoding = typename decltype(MakeQDramTileDistribution<Problem>())::DstrEncode;
|
||||
constexpr index_t y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
|
||||
static_assert(y_ndim >= 2);
|
||||
using shuffled_encoding_t =
|
||||
tile_distribution_encoding_shuffle_t<dram_encoding,
|
||||
remove_cvref_t<decltype(swap_last2<y_ndim>)>>;
|
||||
return make_static_tile_distribution(shuffled_encoding_t{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -1052,10 +1270,30 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
||||
constexpr index_t kKPackT = GetSmemKPackQT<Problem>();
|
||||
|
||||
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
|
||||
using dram_encoding = typename decltype(MakeQDramTileDistribution<Problem>())::DstrEncode;
|
||||
constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
|
||||
if constexpr(dram_y_ndim == 2)
|
||||
{
|
||||
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
||||
constexpr index_t kKPackT = GetSmemKPackQT<Problem>();
|
||||
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
|
||||
}
|
||||
else if constexpr(dram_y_ndim == 3)
|
||||
{
|
||||
constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
|
||||
constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
|
||||
constexpr index_t kKPackT = typename dram_encoding::HsLengthss{}.at(number<0>{}).at(2);
|
||||
return MakeXTLdsBlockDescriptor<Problem,
|
||||
KIter,
|
||||
kNPerBlock / KIter,
|
||||
kKPerBlock,
|
||||
kKPack,
|
||||
kKPackT>();
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unexpected dram y dimension");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -1103,6 +1341,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
qt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto qt_block_dstr = make_static_tile_distribution(qt_block_dstr_encode);
|
||||
static_assert(container_reduce(qt_block_dstr.get_lengths(),
|
||||
std::multiplies<index_t>{},
|
||||
1) == kNPerBlock * kKPerBlock);
|
||||
|
||||
return qt_block_dstr;
|
||||
}
|
||||
@@ -1135,7 +1376,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
dst_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto dst_block_dstr = make_static_tile_distribution(dst_block_dstr_encode);
|
||||
|
||||
static_assert(container_reduce(dst_block_dstr.get_lengths(),
|
||||
std::multiplies<index_t>{},
|
||||
1) == kMPerBlock * kKPerBlock);
|
||||
return dst_block_dstr;
|
||||
}
|
||||
|
||||
@@ -1177,13 +1420,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t M1 = MWarp;
|
||||
constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
constexpr auto dstr = make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<N0, N1>,
|
||||
tuple<sequence<M0, M1, M2, M3, M4>>,
|
||||
tuple<sequence<1, 0>, sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>, sequence<3, 1>>,
|
||||
sequence<1, 1, 1>,
|
||||
sequence<0, 2, 4>>{});
|
||||
static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kMPerBlock);
|
||||
return dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -1193,9 +1439,24 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
|
||||
|
||||
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
|
||||
using dram_encoding =
|
||||
typename decltype(MakeOGradDramTileDistribution<Problem>())::DstrEncode;
|
||||
constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
|
||||
if constexpr(dram_y_ndim == 2)
|
||||
{
|
||||
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
|
||||
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
|
||||
}
|
||||
else if constexpr(dram_y_ndim == 3)
|
||||
{
|
||||
constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
|
||||
constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
|
||||
return MakeXLdsBlockDescriptor<KIter, kMPerBlock, kKPerBlock / KIter, kKPack>();
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unexpected dram y dimension");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -1226,30 +1487,24 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
do_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto do_block_dstr = make_static_tile_distribution(do_block_dstr_encode);
|
||||
|
||||
static_assert(container_reduce(do_block_dstr.get_lengths(),
|
||||
std::multiplies<index_t>{},
|
||||
1) == kMPerBlock * kKPerBlock);
|
||||
return do_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradRegWriteBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentOGrad<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = GetTransposedAlignmentOGrad<Problem>();
|
||||
constexpr index_t N1 = get_warp_size() / K0;
|
||||
constexpr index_t N0 = kBlockSize / get_warp_size();
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>>{});
|
||||
using dram_encoding =
|
||||
typename decltype(MakeOGradDramTileDistribution<Problem>())::DstrEncode;
|
||||
constexpr index_t y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
|
||||
static_assert(y_ndim >= 2);
|
||||
using shuffled_encoding_t =
|
||||
tile_distribution_encoding_shuffle_t<dram_encoding,
|
||||
remove_cvref_t<decltype(swap_last2<y_ndim>)>>;
|
||||
return make_static_tile_distribution(shuffled_encoding_t{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -1259,10 +1514,31 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
|
||||
constexpr index_t kKPackT = GetSmemKPackOGradT<Problem>();
|
||||
|
||||
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
|
||||
using dram_encoding =
|
||||
typename decltype(MakeOGradDramTileDistribution<Problem>())::DstrEncode;
|
||||
constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
|
||||
if constexpr(dram_y_ndim == 2)
|
||||
{
|
||||
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
|
||||
constexpr index_t kKPackT = GetSmemKPackOGradT<Problem>();
|
||||
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
|
||||
}
|
||||
else if constexpr(dram_y_ndim == 3)
|
||||
{
|
||||
constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
|
||||
constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
|
||||
constexpr index_t kKPackT = typename dram_encoding::HsLengthss{}.at(number<0>{}).at(2);
|
||||
return MakeXTLdsBlockDescriptor<Problem,
|
||||
KIter,
|
||||
kNPerBlock / KIter,
|
||||
kKPerBlock,
|
||||
kKPack,
|
||||
kKPackT>();
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unexpected dram y dimension");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -1310,7 +1586,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
dot_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto dot_block_dstr = make_static_tile_distribution(dot_block_dstr_encode);
|
||||
|
||||
static_assert(container_reduce(dot_block_dstr.get_lengths(),
|
||||
std::multiplies<index_t>{},
|
||||
1) == kNPerBlock * kKPerBlock);
|
||||
return dot_block_dstr;
|
||||
}
|
||||
|
||||
@@ -1342,7 +1620,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
pt_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto pt_block_dstr = make_static_tile_distribution(pt_block_dstr_encode);
|
||||
|
||||
static_assert(container_reduce(pt_block_dstr.get_lengths(),
|
||||
std::multiplies<index_t>{},
|
||||
1) == kMPerBlock * kKPerBlock);
|
||||
return pt_block_dstr;
|
||||
}
|
||||
|
||||
@@ -1384,7 +1664,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
ds_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto ds_block_dstr = make_static_tile_distribution(ds_block_dstr_encode);
|
||||
|
||||
static_assert(container_reduce(ds_block_dstr.get_lengths(),
|
||||
std::multiplies<index_t>{},
|
||||
1) == kMPerBlock * kKPerBlock);
|
||||
return ds_block_dstr;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user