|
|
|
|
@@ -1,5 +1,5 @@
|
|
|
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
|
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
|
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
@@ -22,6 +22,13 @@ namespace ck_tile {
|
|
|
|
|
|
|
|
|
|
struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
{
|
|
|
|
|
template <index_t ndim>
|
|
|
|
|
static constexpr auto swap_last2 = generate_sequence_v2(
|
|
|
|
|
[](auto i) {
|
|
|
|
|
return number < i == ndim - 2 ? ndim - 1 : i == ndim - 1 ? ndim - 2 : i > {};
|
|
|
|
|
},
|
|
|
|
|
number<ndim>{});
|
|
|
|
|
|
|
|
|
|
template <typename Problem>
|
|
|
|
|
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
|
|
|
|
|
{
|
|
|
|
|
@@ -384,13 +391,40 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
constexpr index_t N0 = kBlockSize / get_warp_size();
|
|
|
|
|
constexpr index_t N2 = kNPerBlock / (N1 * N0);
|
|
|
|
|
|
|
|
|
|
return make_static_tile_distribution(
|
|
|
|
|
constexpr auto dstr = make_static_tile_distribution(
|
|
|
|
|
tile_distribution_encoding<sequence<>,
|
|
|
|
|
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
|
|
|
|
tuple<sequence<1>, sequence<1, 2>>,
|
|
|
|
|
tuple<sequence<0>, sequence<1, 0>>,
|
|
|
|
|
sequence<1, 2>,
|
|
|
|
|
sequence<2, 1>>{});
|
|
|
|
|
|
|
|
|
|
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
|
|
|
|
kNPerBlock * kKPerBlock)
|
|
|
|
|
{
|
|
|
|
|
return dstr;
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
constexpr index_t kKPerIter = 32;
|
|
|
|
|
static_assert(kKPerBlock % kKPerIter == 0);
|
|
|
|
|
constexpr index_t K0_m = kKPerBlock / kKPerIter;
|
|
|
|
|
constexpr index_t K2 = 2;
|
|
|
|
|
constexpr index_t K1_m = kKPerIter / K2;
|
|
|
|
|
constexpr index_t N1_m = get_warp_size() / K1_m;
|
|
|
|
|
constexpr index_t N2_m = kNPerBlock / (N1_m * N0);
|
|
|
|
|
constexpr auto dstr_m = make_static_tile_distribution(
|
|
|
|
|
tile_distribution_encoding<
|
|
|
|
|
sequence<>,
|
|
|
|
|
tuple<sequence<N0, N1_m, N2_m>, sequence<K0_m, K1_m, K2>>,
|
|
|
|
|
tuple<sequence<1>, sequence<1, 2>>, // N0, N1 K1
|
|
|
|
|
tuple<sequence<0>, sequence<1, 1>>,
|
|
|
|
|
sequence<2, 1, 2>, // K0 N2 K2
|
|
|
|
|
sequence<0, 2, 2>>{});
|
|
|
|
|
static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
|
|
|
|
kNPerBlock * kKPerBlock);
|
|
|
|
|
return dstr_m;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Problem>
|
|
|
|
|
@@ -407,13 +441,39 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
constexpr index_t N1 = kBlockSize / get_warp_size();
|
|
|
|
|
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
|
|
|
|
|
|
|
|
|
return make_static_tile_distribution(
|
|
|
|
|
constexpr auto dstr = make_static_tile_distribution(
|
|
|
|
|
tile_distribution_encoding<sequence<>,
|
|
|
|
|
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
|
|
|
|
tuple<sequence<1>, sequence<1, 2>>,
|
|
|
|
|
tuple<sequence<1>, sequence<1, 2>>, // N1, N2 K0
|
|
|
|
|
tuple<sequence<1>, sequence<2, 0>>,
|
|
|
|
|
sequence<1, 2>,
|
|
|
|
|
sequence<1, 2>, // N0 K1
|
|
|
|
|
sequence<0, 1>>{});
|
|
|
|
|
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
|
|
|
|
kNPerBlock * kKPerBlock)
|
|
|
|
|
{
|
|
|
|
|
return dstr;
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
constexpr index_t kKPerIter = 32;
|
|
|
|
|
static_assert(kKPerBlock % kKPerIter == 0);
|
|
|
|
|
constexpr index_t K0_m = kKPerBlock / kKPerIter;
|
|
|
|
|
constexpr index_t K2 = 2;
|
|
|
|
|
constexpr index_t K1_m = kKPerIter / K2;
|
|
|
|
|
constexpr index_t N2_m = get_warp_size() / K1_m;
|
|
|
|
|
constexpr index_t N0_m = kNPerBlock / (N2_m * N1);
|
|
|
|
|
constexpr auto dstr_m = make_static_tile_distribution(
|
|
|
|
|
tile_distribution_encoding<
|
|
|
|
|
sequence<>,
|
|
|
|
|
tuple<sequence<N0_m, N1, N2_m>, sequence<K0_m, K1_m, K2>>,
|
|
|
|
|
tuple<sequence<1>, sequence<1, 2>>, // N1, N2 K1
|
|
|
|
|
tuple<sequence<1>, sequence<2, 1>>,
|
|
|
|
|
sequence<2, 1, 2>, // K0 N0 K2
|
|
|
|
|
sequence<0, 0, 2>>{});
|
|
|
|
|
static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
|
|
|
|
kNPerBlock * kKPerBlock);
|
|
|
|
|
return dstr_m;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Problem>
|
|
|
|
|
@@ -430,13 +490,41 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
constexpr index_t M0 = kBlockSize / get_warp_size();
|
|
|
|
|
constexpr index_t M2 = kMPerBlock / (M1 * M0);
|
|
|
|
|
|
|
|
|
|
return make_static_tile_distribution(
|
|
|
|
|
constexpr auto dstr = make_static_tile_distribution(
|
|
|
|
|
tile_distribution_encoding<sequence<>,
|
|
|
|
|
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
|
|
|
|
tuple<sequence<1>, sequence<1, 2>>,
|
|
|
|
|
tuple<sequence<0>, sequence<1, 0>>,
|
|
|
|
|
sequence<1, 2>,
|
|
|
|
|
sequence<2, 1>>{});
|
|
|
|
|
|
|
|
|
|
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
|
|
|
|
kMPerBlock * kKPerBlock)
|
|
|
|
|
{
|
|
|
|
|
return dstr;
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
// something not divisible, try a more flexible distribution
|
|
|
|
|
constexpr index_t kKPerIter = 32;
|
|
|
|
|
static_assert(kKPerBlock % kKPerIter == 0);
|
|
|
|
|
constexpr index_t K0_m = kKPerBlock / kKPerIter;
|
|
|
|
|
constexpr index_t K2 = 2;
|
|
|
|
|
constexpr index_t K1_m = kKPerIter / K2;
|
|
|
|
|
constexpr index_t M1_m = get_warp_size() / K1_m;
|
|
|
|
|
constexpr index_t M2_m = kMPerBlock / (M1_m * M0);
|
|
|
|
|
constexpr auto dstr_m = make_static_tile_distribution(
|
|
|
|
|
tile_distribution_encoding<
|
|
|
|
|
sequence<>,
|
|
|
|
|
tuple<sequence<M0, M1_m, M2_m>, sequence<K0_m, K1_m, K2>>,
|
|
|
|
|
tuple<sequence<1>, sequence<1, 2>>, // M0, M1 K1
|
|
|
|
|
tuple<sequence<0>, sequence<1, 1>>,
|
|
|
|
|
sequence<2, 1, 2>, // K0 M2 K2
|
|
|
|
|
sequence<0, 2, 2>>{});
|
|
|
|
|
static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
|
|
|
|
kMPerBlock * kKPerBlock);
|
|
|
|
|
return dstr_m;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Problem>
|
|
|
|
|
@@ -453,13 +541,41 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
constexpr index_t M0 = kBlockSize / get_warp_size();
|
|
|
|
|
constexpr index_t M2 = kMPerBlock / (M1 * M0);
|
|
|
|
|
|
|
|
|
|
return make_static_tile_distribution(
|
|
|
|
|
constexpr auto dstr = make_static_tile_distribution(
|
|
|
|
|
tile_distribution_encoding<sequence<>,
|
|
|
|
|
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
|
|
|
|
tuple<sequence<1>, sequence<1, 2>>,
|
|
|
|
|
tuple<sequence<0>, sequence<1, 0>>,
|
|
|
|
|
sequence<1, 2>,
|
|
|
|
|
sequence<2, 1>>{});
|
|
|
|
|
|
|
|
|
|
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
|
|
|
|
kMPerBlock * kKPerBlock)
|
|
|
|
|
{
|
|
|
|
|
return dstr;
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
// something not divisible, try a more flexible distribution
|
|
|
|
|
constexpr index_t kKPerIter = 32;
|
|
|
|
|
static_assert(kKPerBlock % kKPerIter == 0);
|
|
|
|
|
constexpr index_t K0_m = kKPerBlock / kKPerIter;
|
|
|
|
|
constexpr index_t K2 = 2;
|
|
|
|
|
constexpr index_t K1_m = kKPerIter / K2;
|
|
|
|
|
constexpr index_t M1_m = get_warp_size() / K1_m;
|
|
|
|
|
constexpr index_t M2_m = kMPerBlock / (M1_m * M0);
|
|
|
|
|
constexpr auto dstr_m = make_static_tile_distribution(
|
|
|
|
|
tile_distribution_encoding<
|
|
|
|
|
sequence<>,
|
|
|
|
|
tuple<sequence<M0, M1_m, M2_m>, sequence<K0_m, K1_m, K2>>,
|
|
|
|
|
tuple<sequence<1>, sequence<1, 2>>, // M0, M1 K1
|
|
|
|
|
tuple<sequence<0>, sequence<1, 1>>,
|
|
|
|
|
sequence<2, 1, 2>, // K0 M2 K2
|
|
|
|
|
sequence<0, 2, 2>>{});
|
|
|
|
|
static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
|
|
|
|
kMPerBlock * kKPerBlock);
|
|
|
|
|
return dstr_m;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Problem, typename BlockGemm>
|
|
|
|
|
@@ -504,13 +620,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
constexpr index_t M0 = kBlockSize / get_warp_size();
|
|
|
|
|
constexpr index_t M2 = kMPerBlock / (M1 * M0);
|
|
|
|
|
|
|
|
|
|
return make_static_tile_distribution(
|
|
|
|
|
constexpr auto dstr = make_static_tile_distribution(
|
|
|
|
|
tile_distribution_encoding<sequence<>,
|
|
|
|
|
tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
|
|
|
|
|
tuple<sequence<1>, sequence<1, 2>>,
|
|
|
|
|
tuple<sequence<0>, sequence<1, 0>>,
|
|
|
|
|
sequence<1, 2>,
|
|
|
|
|
sequence<2, 1>>{});
|
|
|
|
|
static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
|
|
|
|
kMPerBlock * kNPerBlock);
|
|
|
|
|
return dstr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename DataType, index_t MPerBlock, index_t KPerBlock>
|
|
|
|
|
@@ -522,13 +641,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
constexpr index_t M1 = get_warp_size();
|
|
|
|
|
constexpr index_t M0 = MPerBlock / M1;
|
|
|
|
|
|
|
|
|
|
return make_static_tile_distribution(
|
|
|
|
|
constexpr auto dstr = make_static_tile_distribution(
|
|
|
|
|
tile_distribution_encoding<sequence<>,
|
|
|
|
|
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
|
|
|
|
tuple<sequence<1>, sequence<1>>,
|
|
|
|
|
tuple<sequence<0>, sequence<1>>,
|
|
|
|
|
sequence<1, 2, 2>,
|
|
|
|
|
sequence<2, 0, 1>>{});
|
|
|
|
|
static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
|
|
|
|
MPerBlock * KPerBlock);
|
|
|
|
|
return dstr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Problem>
|
|
|
|
|
@@ -569,13 +691,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
constexpr index_t M1 = kBlockSize / get_warp_size();
|
|
|
|
|
constexpr index_t M0 = kMPerBlock / (M1 * M2);
|
|
|
|
|
|
|
|
|
|
return make_static_tile_distribution(
|
|
|
|
|
constexpr auto dstr = make_static_tile_distribution(
|
|
|
|
|
tile_distribution_encoding<sequence<>,
|
|
|
|
|
tuple<sequence<1>, sequence<M0, M1, M2>, sequence<K0, K1>>,
|
|
|
|
|
tuple<sequence<2>, sequence<2, 3>>,
|
|
|
|
|
tuple<sequence<1>, sequence<2, 0>>,
|
|
|
|
|
sequence<1, 2, 3>,
|
|
|
|
|
sequence<0, 0, 1>>{});
|
|
|
|
|
static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
|
|
|
|
kMPerBlock * kKPerBlock);
|
|
|
|
|
return dstr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Problem>
|
|
|
|
|
@@ -594,13 +719,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
constexpr index_t M1 = kBlockSize / get_warp_size();
|
|
|
|
|
constexpr index_t M0 = kMPerBlock / (M1 * M2);
|
|
|
|
|
|
|
|
|
|
return make_static_tile_distribution(
|
|
|
|
|
constexpr auto dstr = make_static_tile_distribution(
|
|
|
|
|
tile_distribution_encoding<sequence<>,
|
|
|
|
|
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
|
|
|
|
tuple<sequence<1>, sequence<1, 2>>,
|
|
|
|
|
tuple<sequence<1>, sequence<2, 0>>,
|
|
|
|
|
sequence<1, 2>,
|
|
|
|
|
sequence<0, 1>>{});
|
|
|
|
|
|
|
|
|
|
static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
|
|
|
|
kMPerBlock * kKPerBlock);
|
|
|
|
|
return dstr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// these are for lds
|
|
|
|
|
@@ -666,56 +795,80 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
return 16 / sizeof(GemmDataType);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <index_t MNPerBlock, index_t KPerBlock, index_t KPack>
|
|
|
|
|
template <index_t KIter, index_t MNPerBlock, index_t KPerSubBlock, index_t KPack>
|
|
|
|
|
CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor()
|
|
|
|
|
{
|
|
|
|
|
constexpr auto DataTypeSize = 2; // sizeof(F16/BF16)
|
|
|
|
|
constexpr auto MNLdsLayer =
|
|
|
|
|
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
|
|
|
|
|
(32 * 4 / KPerSubBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerSubBlock / DataTypeSize);
|
|
|
|
|
|
|
|
|
|
constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor(
|
|
|
|
|
make_tuple(number<KPerBlock / KPack * MNLdsLayer>{},
|
|
|
|
|
number<MNPerBlock / MNLdsLayer>{},
|
|
|
|
|
number<KPack>{}),
|
|
|
|
|
make_tuple(number<KPack>{}, number<KPerBlock * MNLdsLayer>{}, number<1>{}),
|
|
|
|
|
number<KPack>{},
|
|
|
|
|
number<1>{});
|
|
|
|
|
constexpr auto x_lds_block_desc_0 =
|
|
|
|
|
make_naive_tensor_descriptor(make_tuple(number<KIter>{},
|
|
|
|
|
number<KPerSubBlock / KPack * MNLdsLayer>{},
|
|
|
|
|
number<MNPerBlock / MNLdsLayer>{},
|
|
|
|
|
number<KPack>{}),
|
|
|
|
|
make_tuple(number<KPerSubBlock * MNPerBlock>{},
|
|
|
|
|
number<KPack>{},
|
|
|
|
|
number<KPerSubBlock * MNLdsLayer>{},
|
|
|
|
|
number<1>{}),
|
|
|
|
|
number<KPack>{},
|
|
|
|
|
number<1>{});
|
|
|
|
|
|
|
|
|
|
constexpr auto x_lds_block_desc_permuted = transform_tensor_descriptor(
|
|
|
|
|
x_lds_block_desc_0,
|
|
|
|
|
make_tuple(make_xor_transform(make_tuple(number<MNPerBlock / MNLdsLayer>{},
|
|
|
|
|
number<KPerBlock / KPack * MNLdsLayer>{})),
|
|
|
|
|
make_tuple(make_pass_through_transform(number<KIter>{}),
|
|
|
|
|
make_xor_transform(make_tuple(number<MNPerBlock / MNLdsLayer>{},
|
|
|
|
|
number<KPerSubBlock / KPack * MNLdsLayer>{})),
|
|
|
|
|
make_pass_through_transform(number<KPack>{})),
|
|
|
|
|
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
|
|
|
|
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
|
|
|
|
make_tuple(sequence<0>{}, sequence<2, 1>{}, sequence<3>{}),
|
|
|
|
|
make_tuple(sequence<0>{}, sequence<2, 1>{}, sequence<3>{}));
|
|
|
|
|
|
|
|
|
|
constexpr auto x_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
|
|
|
|
x_lds_block_desc_permuted,
|
|
|
|
|
make_tuple(make_unmerge_transform(
|
|
|
|
|
make_tuple(number<KPerBlock / KPack>{}, number<MNLdsLayer>{})),
|
|
|
|
|
make_tuple(make_pass_through_transform(number<KIter>{}),
|
|
|
|
|
make_unmerge_transform(
|
|
|
|
|
make_tuple(number<KPerSubBlock / KPack>{}, number<MNLdsLayer>{})),
|
|
|
|
|
make_pass_through_transform(number<MNPerBlock / MNLdsLayer>{}),
|
|
|
|
|
make_pass_through_transform(number<KPack>{})),
|
|
|
|
|
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
|
|
|
|
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
|
|
|
|
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
|
|
|
|
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}));
|
|
|
|
|
|
|
|
|
|
constexpr auto x_lds_block_desc = transform_tensor_descriptor(
|
|
|
|
|
x_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
|
|
|
|
make_tuple(make_merge_transform_v3_division_mod(
|
|
|
|
|
make_tuple(number<MNPerBlock / MNLdsLayer>{}, number<MNLdsLayer>{})),
|
|
|
|
|
make_merge_transform_v3_division_mod(
|
|
|
|
|
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
|
|
|
|
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
|
|
|
|
|
make_merge_transform_v3_division_mod(make_tuple(
|
|
|
|
|
number<KIter>{}, number<KPerSubBlock / KPack>{}, number<KPack>{}))),
|
|
|
|
|
make_tuple(sequence<2, 3>{}, sequence<0, 1, 4>{}),
|
|
|
|
|
make_tuple(sequence<0>{}, sequence<1>{}));
|
|
|
|
|
|
|
|
|
|
static_assert(container_reduce(x_lds_block_desc.get_lengths(),
|
|
|
|
|
std::multiplies<index_t>{},
|
|
|
|
|
1) == KIter * MNPerBlock * KPerSubBlock);
|
|
|
|
|
return x_lds_block_desc;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <index_t MNPerBlock, index_t KPerBlock, index_t KPack>
|
|
|
|
|
CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor()
|
|
|
|
|
{
|
|
|
|
|
return MakeXLdsBlockDescriptor<1, MNPerBlock, KPerBlock, KPack>();
|
|
|
|
|
}
|
|
|
|
|
template <typename Problem,
|
|
|
|
|
index_t MNPerBlock,
|
|
|
|
|
index_t KPerBlock,
|
|
|
|
|
index_t KPack,
|
|
|
|
|
index_t KPackT>
|
|
|
|
|
CK_TILE_HOST_DEVICE static constexpr auto MakeXTLdsBlockDescriptor()
|
|
|
|
|
{
|
|
|
|
|
return MakeXTLdsBlockDescriptor<Problem, 1, MNPerBlock, KPerBlock, KPack, KPackT>();
|
|
|
|
|
}
|
|
|
|
|
template <typename Problem,
|
|
|
|
|
index_t MNIter,
|
|
|
|
|
index_t MNPerSubBlock,
|
|
|
|
|
index_t KPerBlock,
|
|
|
|
|
index_t KPack,
|
|
|
|
|
index_t KPackT>
|
|
|
|
|
CK_TILE_HOST_DEVICE static constexpr auto MakeXTLdsBlockDescriptor()
|
|
|
|
|
{
|
|
|
|
|
// kfold and mpair dimension is not always required.
|
|
|
|
|
// more dimension in merge_transform increase the difficulty of generating immarg offset
|
|
|
|
|
@@ -723,7 +876,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
constexpr auto MNPerXDL = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
|
|
|
|
|
constexpr auto kBlockSize = Problem::kBlockSize;
|
|
|
|
|
|
|
|
|
|
constexpr auto MN0 = MNPerBlock / KPack;
|
|
|
|
|
constexpr auto MN0 = MNPerSubBlock / KPack;
|
|
|
|
|
constexpr auto MN1 = KPack;
|
|
|
|
|
|
|
|
|
|
constexpr auto KThreadWrite = kBlockSize / MN0;
|
|
|
|
|
@@ -745,13 +898,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
: ((128 / (KPackT * MNPerXDL * 2)) > MN0 ? MN0 : 128 / (KPackT * MNPerXDL * 2));
|
|
|
|
|
|
|
|
|
|
constexpr auto xt_lds_block_desc_raw = make_naive_tensor_descriptor(
|
|
|
|
|
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
|
|
|
|
|
make_tuple(number<MNIter>{},
|
|
|
|
|
number<KThreadWrite / kfold / KThreadReadPerm>{},
|
|
|
|
|
number<K0PerThreadWrite>{},
|
|
|
|
|
number<KThreadReadPerm * MN1>{},
|
|
|
|
|
number<kfold * MN0 / mnpair>{},
|
|
|
|
|
number<mnpair>{},
|
|
|
|
|
KPackT),
|
|
|
|
|
make_tuple(number<KPackT * kfold * MN0 * KThreadReadPerm * MN1 * K0PerThreadWrite>{},
|
|
|
|
|
make_tuple(number<KPackT * MN0 * KThreadWrite * MN1 * K0PerThreadWrite>{},
|
|
|
|
|
number<KPackT * kfold * MN0 * KThreadReadPerm * MN1 * K0PerThreadWrite>{},
|
|
|
|
|
number<KPackT * kfold * MN0 * KThreadReadPerm * MN1>{},
|
|
|
|
|
number<KPackT * kfold * MN0>{},
|
|
|
|
|
number<KPackT * mnpair>{},
|
|
|
|
|
@@ -763,20 +918,30 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
constexpr auto xt_lds_block_desc_permuted = transform_tensor_descriptor(
|
|
|
|
|
xt_lds_block_desc_raw,
|
|
|
|
|
make_tuple(
|
|
|
|
|
make_pass_through_transform(number<MNIter>{}),
|
|
|
|
|
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
|
|
|
|
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
|
|
|
|
make_xor_transform(
|
|
|
|
|
make_tuple(number<KThreadReadPerm * MN1>{}, number<kfold * MN0 / mnpair>{})),
|
|
|
|
|
make_pass_through_transform(number<mnpair>{}),
|
|
|
|
|
make_pass_through_transform(KPackT)),
|
|
|
|
|
make_tuple(
|
|
|
|
|
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
|
|
|
|
|
make_tuple(
|
|
|
|
|
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));
|
|
|
|
|
make_tuple(sequence<0>{},
|
|
|
|
|
sequence<1>{},
|
|
|
|
|
sequence<2>{},
|
|
|
|
|
sequence<3, 4>{},
|
|
|
|
|
sequence<5>{},
|
|
|
|
|
sequence<6>{}),
|
|
|
|
|
make_tuple(sequence<0>{},
|
|
|
|
|
sequence<1>{},
|
|
|
|
|
sequence<2>{},
|
|
|
|
|
sequence<3, 4>{},
|
|
|
|
|
sequence<5>{},
|
|
|
|
|
sequence<6>{}));
|
|
|
|
|
|
|
|
|
|
constexpr auto xt_lds_block_desc_unmerged = transform_tensor_descriptor(
|
|
|
|
|
xt_lds_block_desc_permuted,
|
|
|
|
|
make_tuple(
|
|
|
|
|
make_pass_through_transform(number<MNIter>{}),
|
|
|
|
|
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
|
|
|
|
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
|
|
|
|
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<MN1>{})),
|
|
|
|
|
@@ -788,27 +953,32 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
sequence<2>{},
|
|
|
|
|
sequence<3>{},
|
|
|
|
|
sequence<4>{},
|
|
|
|
|
sequence<5>{}),
|
|
|
|
|
make_tuple(sequence<1>{},
|
|
|
|
|
sequence<5>{},
|
|
|
|
|
sequence<6>{}),
|
|
|
|
|
make_tuple(sequence<0>{},
|
|
|
|
|
sequence<2>{},
|
|
|
|
|
sequence<0, 3>{},
|
|
|
|
|
sequence<4, 5>{},
|
|
|
|
|
sequence<6>{},
|
|
|
|
|
sequence<7>{}));
|
|
|
|
|
sequence<3>{},
|
|
|
|
|
sequence<1, 4>{},
|
|
|
|
|
sequence<5, 6>{},
|
|
|
|
|
sequence<7>{},
|
|
|
|
|
sequence<8>{}));
|
|
|
|
|
|
|
|
|
|
constexpr auto xt_lds_block_desc = transform_tensor_descriptor(
|
|
|
|
|
xt_lds_block_desc_unmerged,
|
|
|
|
|
make_tuple(make_merge_transform_v3_division_mod(
|
|
|
|
|
make_tuple(number<KThreadReadPerm>{},
|
|
|
|
|
number<KThreadWrite / kfold / KThreadReadPerm>{},
|
|
|
|
|
number<kfold>{},
|
|
|
|
|
number<K0PerThreadWrite>{},
|
|
|
|
|
number<KPackT>{})),
|
|
|
|
|
make_merge_transform_v3_division_mod(
|
|
|
|
|
make_tuple(number<MN0 / mnpair>{}, number<mnpair>{}, number<MN1>{}))),
|
|
|
|
|
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
|
|
|
|
|
make_tuple(
|
|
|
|
|
make_merge_transform_v3_division_mod(
|
|
|
|
|
make_tuple(number<KThreadReadPerm>{},
|
|
|
|
|
number<KThreadWrite / kfold / KThreadReadPerm>{},
|
|
|
|
|
number<kfold>{},
|
|
|
|
|
number<K0PerThreadWrite>{},
|
|
|
|
|
number<KPackT>{})),
|
|
|
|
|
make_merge_transform_v3_division_mod(make_tuple(
|
|
|
|
|
number<MNIter>{}, number<MN0 / mnpair>{}, number<mnpair>{}, number<MN1>{}))),
|
|
|
|
|
make_tuple(sequence<1, 2, 5, 3, 8>{}, sequence<0, 6, 7, 4>{}),
|
|
|
|
|
make_tuple(sequence<0>{}, sequence<1>{}));
|
|
|
|
|
|
|
|
|
|
static_assert(container_reduce(xt_lds_block_desc.get_lengths(),
|
|
|
|
|
std::multiplies<index_t>{},
|
|
|
|
|
1) == MNPerSubBlock * MNIter * KPerBlock);
|
|
|
|
|
return xt_lds_block_desc;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -817,9 +987,24 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
{
|
|
|
|
|
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
|
|
|
|
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
|
|
|
|
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
|
|
|
|
|
|
|
|
|
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>();
|
|
|
|
|
using dram_encoding = typename decltype(MakeKDramTileDistribution<Problem>())::DstrEncode;
|
|
|
|
|
constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
|
|
|
|
|
if constexpr(dram_y_ndim == 2)
|
|
|
|
|
{
|
|
|
|
|
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
|
|
|
|
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>();
|
|
|
|
|
}
|
|
|
|
|
else if constexpr(dram_y_ndim == 3)
|
|
|
|
|
{
|
|
|
|
|
constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
|
|
|
|
|
constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
|
|
|
|
|
return MakeXLdsBlockDescriptor<KIter, kNPerBlock, kKPerBlock / KIter, kKPack>();
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
static_assert(false, "Unexpected dram y dimension");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Problem>
|
|
|
|
|
@@ -850,7 +1035,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
|
|
|
|
|
|
|
|
|
constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
|
|
|
|
|
|
|
|
|
|
static_assert(container_reduce(k_block_dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
|
|
|
|
kNPerBlock * kKPerBlock);
|
|
|
|
|
return k_block_dstr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -860,9 +1046,23 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
|
|
|
|
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
|
|
|
|
|
|
|
|
|
constexpr index_t kVPack = GetSmemKPackV<Problem>();
|
|
|
|
|
|
|
|
|
|
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kVPack>();
|
|
|
|
|
using dram_encoding = typename decltype(MakeVDramTileDistribution<Problem>())::DstrEncode;
|
|
|
|
|
constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
|
|
|
|
|
if constexpr(dram_y_ndim == 2)
|
|
|
|
|
{
|
|
|
|
|
constexpr index_t kVPack = GetSmemKPackV<Problem>();
|
|
|
|
|
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kVPack>();
|
|
|
|
|
}
|
|
|
|
|
else if constexpr(dram_y_ndim == 3)
|
|
|
|
|
{
|
|
|
|
|
constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
|
|
|
|
|
constexpr index_t kVPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
|
|
|
|
|
return MakeXLdsBlockDescriptor<KIter, kNPerBlock, kKPerBlock / KIter, kVPack>();
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
static_assert(false, "Unexpected dram y dimension");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Problem>
|
|
|
|
|
@@ -893,30 +1093,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
|
|
|
|
|
|
|
|
|
constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
|
|
|
|
|
|
|
|
|
|
static_assert(container_reduce(v_block_dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
|
|
|
|
kNPerBlock * kKPerBlock);
|
|
|
|
|
return v_block_dstr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Problem>
|
|
|
|
|
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKRegWriteBlockDescriptor()
|
|
|
|
|
{
|
|
|
|
|
constexpr index_t kBlockSize = Problem::kBlockSize;
|
|
|
|
|
|
|
|
|
|
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
|
|
|
|
|
|
|
|
|
constexpr index_t K1 = GetAlignmentK<Problem>();
|
|
|
|
|
constexpr index_t K0 = kKPerBlock / K1;
|
|
|
|
|
constexpr index_t N2 = GetTransposedAlignmentK<Problem>();
|
|
|
|
|
constexpr index_t N1 = get_warp_size() / K0;
|
|
|
|
|
constexpr index_t N0 = kBlockSize / get_warp_size();
|
|
|
|
|
|
|
|
|
|
return make_static_tile_distribution(
|
|
|
|
|
tile_distribution_encoding<sequence<>,
|
|
|
|
|
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
|
|
|
|
tuple<sequence<1>, sequence<1, 2>>,
|
|
|
|
|
tuple<sequence<0>, sequence<1, 0>>,
|
|
|
|
|
sequence<2, 1>,
|
|
|
|
|
sequence<1, 2>>{});
|
|
|
|
|
using dram_encoding = typename decltype(MakeKDramTileDistribution<Problem>())::DstrEncode;
|
|
|
|
|
constexpr index_t y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
|
|
|
|
|
static_assert(y_ndim >= 2);
|
|
|
|
|
using shuffled_encoding_t =
|
|
|
|
|
tile_distribution_encoding_shuffle_t<dram_encoding,
|
|
|
|
|
remove_cvref_t<decltype(swap_last2<y_ndim>)>>;
|
|
|
|
|
return make_static_tile_distribution(shuffled_encoding_t{});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Problem>
|
|
|
|
|
@@ -926,10 +1117,30 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
|
|
|
|
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
|
|
|
|
|
|
|
|
|
|
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
|
|
|
|
constexpr index_t kKPackT = GetSmemKPackKT<Problem>();
|
|
|
|
|
|
|
|
|
|
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
|
|
|
|
|
using dram_encoding = typename decltype(MakeKDramTileDistribution<Problem>())::DstrEncode;
|
|
|
|
|
constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
|
|
|
|
|
if constexpr(dram_y_ndim == 2)
|
|
|
|
|
{
|
|
|
|
|
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
|
|
|
|
constexpr index_t kKPackT = GetSmemKPackKT<Problem>();
|
|
|
|
|
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
|
|
|
|
|
}
|
|
|
|
|
else if constexpr(dram_y_ndim == 3)
|
|
|
|
|
{
|
|
|
|
|
constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
|
|
|
|
|
constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
|
|
|
|
|
constexpr index_t kKPackT = typename dram_encoding::HsLengthss{}.at(number<0>{}).at(2);
|
|
|
|
|
return MakeXTLdsBlockDescriptor<Problem,
|
|
|
|
|
KIter,
|
|
|
|
|
kNPerBlock / KIter,
|
|
|
|
|
kKPerBlock,
|
|
|
|
|
kKPack,
|
|
|
|
|
kKPackT>();
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
static_assert(false, "Unexpected dram y dimension");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Problem>
|
|
|
|
|
@@ -976,7 +1187,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
kt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
|
|
|
|
|
|
|
|
|
constexpr auto kt_block_dstr = make_static_tile_distribution(kt_block_dstr_encode);
|
|
|
|
|
|
|
|
|
|
static_assert(container_reduce(kt_block_dstr.get_lengths(),
|
|
|
|
|
std::multiplies<index_t>{},
|
|
|
|
|
1) == kNPerBlock * kKPerBlock);
|
|
|
|
|
return kt_block_dstr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -986,9 +1199,23 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
|
|
|
|
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
|
|
|
|
|
|
|
|
|
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
|
|
|
|
|
|
|
|
|
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
|
|
|
|
|
using dram_encoding = typename decltype(MakeQDramTileDistribution<Problem>())::DstrEncode;
|
|
|
|
|
constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
|
|
|
|
|
if constexpr(dram_y_ndim == 2)
|
|
|
|
|
{
|
|
|
|
|
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
|
|
|
|
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
|
|
|
|
|
}
|
|
|
|
|
else if constexpr(dram_y_ndim == 3)
|
|
|
|
|
{
|
|
|
|
|
constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
|
|
|
|
|
constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
|
|
|
|
|
return MakeXLdsBlockDescriptor<KIter, kMPerBlock, kKPerBlock / KIter, kKPack>();
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
static_assert(false, "Unexpected dram y dimension");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Problem>
|
|
|
|
|
@@ -1019,30 +1246,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
|
|
|
|
|
|
|
|
|
constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode);
|
|
|
|
|
|
|
|
|
|
static_assert(container_reduce(q_block_dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
|
|
|
|
kMPerBlock * kKPerBlock);
|
|
|
|
|
return q_block_dstr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Problem>
|
|
|
|
|
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQRegWriteBlockDescriptor()
|
|
|
|
|
{
|
|
|
|
|
constexpr index_t kBlockSize = Problem::kBlockSize;
|
|
|
|
|
|
|
|
|
|
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
|
|
|
|
|
|
|
|
|
constexpr index_t K1 = GetAlignmentQ<Problem>();
|
|
|
|
|
constexpr index_t K0 = kKPerBlock / K1;
|
|
|
|
|
constexpr index_t N2 = GetTransposedAlignmentQ<Problem>();
|
|
|
|
|
constexpr index_t N1 = get_warp_size() / K0;
|
|
|
|
|
constexpr index_t N0 = kBlockSize / get_warp_size();
|
|
|
|
|
|
|
|
|
|
return make_static_tile_distribution(
|
|
|
|
|
tile_distribution_encoding<sequence<>,
|
|
|
|
|
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
|
|
|
|
tuple<sequence<1>, sequence<1, 2>>,
|
|
|
|
|
tuple<sequence<0>, sequence<1, 0>>,
|
|
|
|
|
sequence<2, 1>,
|
|
|
|
|
sequence<1, 2>>{});
|
|
|
|
|
using dram_encoding = typename decltype(MakeQDramTileDistribution<Problem>())::DstrEncode;
|
|
|
|
|
constexpr index_t y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
|
|
|
|
|
static_assert(y_ndim >= 2);
|
|
|
|
|
using shuffled_encoding_t =
|
|
|
|
|
tile_distribution_encoding_shuffle_t<dram_encoding,
|
|
|
|
|
remove_cvref_t<decltype(swap_last2<y_ndim>)>>;
|
|
|
|
|
return make_static_tile_distribution(shuffled_encoding_t{});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Problem>
|
|
|
|
|
@@ -1052,10 +1270,30 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
|
|
|
|
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
|
|
|
|
|
|
|
|
|
|
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
|
|
|
|
constexpr index_t kKPackT = GetSmemKPackQT<Problem>();
|
|
|
|
|
|
|
|
|
|
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
|
|
|
|
|
using dram_encoding = typename decltype(MakeQDramTileDistribution<Problem>())::DstrEncode;
|
|
|
|
|
constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
|
|
|
|
|
if constexpr(dram_y_ndim == 2)
|
|
|
|
|
{
|
|
|
|
|
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
|
|
|
|
constexpr index_t kKPackT = GetSmemKPackQT<Problem>();
|
|
|
|
|
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
|
|
|
|
|
}
|
|
|
|
|
else if constexpr(dram_y_ndim == 3)
|
|
|
|
|
{
|
|
|
|
|
constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
|
|
|
|
|
constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
|
|
|
|
|
constexpr index_t kKPackT = typename dram_encoding::HsLengthss{}.at(number<0>{}).at(2);
|
|
|
|
|
return MakeXTLdsBlockDescriptor<Problem,
|
|
|
|
|
KIter,
|
|
|
|
|
kNPerBlock / KIter,
|
|
|
|
|
kKPerBlock,
|
|
|
|
|
kKPack,
|
|
|
|
|
kKPackT>();
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
static_assert(false, "Unexpected dram y dimension");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Problem>
|
|
|
|
|
@@ -1103,6 +1341,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
qt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
|
|
|
|
|
|
|
|
|
constexpr auto qt_block_dstr = make_static_tile_distribution(qt_block_dstr_encode);
|
|
|
|
|
static_assert(container_reduce(qt_block_dstr.get_lengths(),
|
|
|
|
|
std::multiplies<index_t>{},
|
|
|
|
|
1) == kNPerBlock * kKPerBlock);
|
|
|
|
|
|
|
|
|
|
return qt_block_dstr;
|
|
|
|
|
}
|
|
|
|
|
@@ -1135,7 +1376,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
dst_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
|
|
|
|
|
|
|
|
|
constexpr auto dst_block_dstr = make_static_tile_distribution(dst_block_dstr_encode);
|
|
|
|
|
|
|
|
|
|
static_assert(container_reduce(dst_block_dstr.get_lengths(),
|
|
|
|
|
std::multiplies<index_t>{},
|
|
|
|
|
1) == kMPerBlock * kKPerBlock);
|
|
|
|
|
return dst_block_dstr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -1177,13 +1420,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
constexpr index_t M1 = MWarp;
|
|
|
|
|
constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM);
|
|
|
|
|
|
|
|
|
|
return make_static_tile_distribution(
|
|
|
|
|
constexpr auto dstr = make_static_tile_distribution(
|
|
|
|
|
tile_distribution_encoding<sequence<N0, N1>,
|
|
|
|
|
tuple<sequence<M0, M1, M2, M3, M4>>,
|
|
|
|
|
tuple<sequence<1, 0>, sequence<1, 0>>,
|
|
|
|
|
tuple<sequence<1, 0>, sequence<3, 1>>,
|
|
|
|
|
sequence<1, 1, 1>,
|
|
|
|
|
sequence<0, 2, 4>>{});
|
|
|
|
|
static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
|
|
|
|
kMPerBlock);
|
|
|
|
|
return dstr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Problem>
|
|
|
|
|
@@ -1193,9 +1439,24 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
|
|
|
|
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
|
|
|
|
|
|
|
|
|
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
|
|
|
|
|
|
|
|
|
|
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
|
|
|
|
|
using dram_encoding =
|
|
|
|
|
typename decltype(MakeOGradDramTileDistribution<Problem>())::DstrEncode;
|
|
|
|
|
constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
|
|
|
|
|
if constexpr(dram_y_ndim == 2)
|
|
|
|
|
{
|
|
|
|
|
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
|
|
|
|
|
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
|
|
|
|
|
}
|
|
|
|
|
else if constexpr(dram_y_ndim == 3)
|
|
|
|
|
{
|
|
|
|
|
constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
|
|
|
|
|
constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
|
|
|
|
|
return MakeXLdsBlockDescriptor<KIter, kMPerBlock, kKPerBlock / KIter, kKPack>();
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
static_assert(false, "Unexpected dram y dimension");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Problem>
|
|
|
|
|
@@ -1226,30 +1487,24 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
do_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
|
|
|
|
|
|
|
|
|
constexpr auto do_block_dstr = make_static_tile_distribution(do_block_dstr_encode);
|
|
|
|
|
|
|
|
|
|
static_assert(container_reduce(do_block_dstr.get_lengths(),
|
|
|
|
|
std::multiplies<index_t>{},
|
|
|
|
|
1) == kMPerBlock * kKPerBlock);
|
|
|
|
|
return do_block_dstr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Problem>
|
|
|
|
|
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradRegWriteBlockDescriptor()
|
|
|
|
|
{
|
|
|
|
|
constexpr index_t kBlockSize = Problem::kBlockSize;
|
|
|
|
|
|
|
|
|
|
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
|
|
|
|
|
|
|
|
|
constexpr index_t K1 = GetAlignmentOGrad<Problem>();
|
|
|
|
|
constexpr index_t K0 = kKPerBlock / K1;
|
|
|
|
|
constexpr index_t N2 = GetTransposedAlignmentOGrad<Problem>();
|
|
|
|
|
constexpr index_t N1 = get_warp_size() / K0;
|
|
|
|
|
constexpr index_t N0 = kBlockSize / get_warp_size();
|
|
|
|
|
|
|
|
|
|
return make_static_tile_distribution(
|
|
|
|
|
tile_distribution_encoding<sequence<>,
|
|
|
|
|
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
|
|
|
|
tuple<sequence<1>, sequence<1, 2>>,
|
|
|
|
|
tuple<sequence<0>, sequence<1, 0>>,
|
|
|
|
|
sequence<2, 1>,
|
|
|
|
|
sequence<1, 2>>{});
|
|
|
|
|
using dram_encoding =
|
|
|
|
|
typename decltype(MakeOGradDramTileDistribution<Problem>())::DstrEncode;
|
|
|
|
|
constexpr index_t y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
|
|
|
|
|
static_assert(y_ndim >= 2);
|
|
|
|
|
using shuffled_encoding_t =
|
|
|
|
|
tile_distribution_encoding_shuffle_t<dram_encoding,
|
|
|
|
|
remove_cvref_t<decltype(swap_last2<y_ndim>)>>;
|
|
|
|
|
return make_static_tile_distribution(shuffled_encoding_t{});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Problem>
|
|
|
|
|
@@ -1259,10 +1514,31 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
|
|
|
|
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
|
|
|
|
|
|
|
|
|
|
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
|
|
|
|
|
constexpr index_t kKPackT = GetSmemKPackOGradT<Problem>();
|
|
|
|
|
|
|
|
|
|
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
|
|
|
|
|
using dram_encoding =
|
|
|
|
|
typename decltype(MakeOGradDramTileDistribution<Problem>())::DstrEncode;
|
|
|
|
|
constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
|
|
|
|
|
if constexpr(dram_y_ndim == 2)
|
|
|
|
|
{
|
|
|
|
|
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
|
|
|
|
|
constexpr index_t kKPackT = GetSmemKPackOGradT<Problem>();
|
|
|
|
|
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
|
|
|
|
|
}
|
|
|
|
|
else if constexpr(dram_y_ndim == 3)
|
|
|
|
|
{
|
|
|
|
|
constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
|
|
|
|
|
constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
|
|
|
|
|
constexpr index_t kKPackT = typename dram_encoding::HsLengthss{}.at(number<0>{}).at(2);
|
|
|
|
|
return MakeXTLdsBlockDescriptor<Problem,
|
|
|
|
|
KIter,
|
|
|
|
|
kNPerBlock / KIter,
|
|
|
|
|
kKPerBlock,
|
|
|
|
|
kKPack,
|
|
|
|
|
kKPackT>();
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
static_assert(false, "Unexpected dram y dimension");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Problem>
|
|
|
|
|
@@ -1310,7 +1586,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
dot_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
|
|
|
|
|
|
|
|
|
constexpr auto dot_block_dstr = make_static_tile_distribution(dot_block_dstr_encode);
|
|
|
|
|
|
|
|
|
|
static_assert(container_reduce(dot_block_dstr.get_lengths(),
|
|
|
|
|
std::multiplies<index_t>{},
|
|
|
|
|
1) == kNPerBlock * kKPerBlock);
|
|
|
|
|
return dot_block_dstr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -1342,7 +1620,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
pt_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
|
|
|
|
|
|
|
|
|
constexpr auto pt_block_dstr = make_static_tile_distribution(pt_block_dstr_encode);
|
|
|
|
|
|
|
|
|
|
static_assert(container_reduce(pt_block_dstr.get_lengths(),
|
|
|
|
|
std::multiplies<index_t>{},
|
|
|
|
|
1) == kMPerBlock * kKPerBlock);
|
|
|
|
|
return pt_block_dstr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -1384,7 +1664,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|
|
|
|
ds_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
|
|
|
|
|
|
|
|
|
constexpr auto ds_block_dstr = make_static_tile_distribution(ds_block_dstr_encode);
|
|
|
|
|
|
|
|
|
|
static_assert(container_reduce(ds_block_dstr.get_lengths(),
|
|
|
|
|
std::multiplies<index_t>{},
|
|
|
|
|
1) == kMPerBlock * kKPerBlock);
|
|
|
|
|
return ds_block_dstr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|