mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
Improve the VDramTileDistribution and VLds layout for better device loading and reduced bank conflicts
This commit is contained in:
@@ -480,7 +480,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
Policy::template MakeShuffledVRegTileDistribution<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_tile);
|
||||
|
||||
// if K in this unroll uses Lds-buffer i, then V in this unroll uses Lds-buffer
|
||||
|
||||
@@ -53,6 +53,16 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
return WG::WarpGemmAttribute::kKPerThread;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetKVWarpGemmKPerThreadSize()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
return WG::WarpGemmAttribute::kKPerThread;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
|
||||
{
|
||||
@@ -104,15 +114,10 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
|
||||
{
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
|
||||
return min(ElemPerThread, MaxVectorSize);
|
||||
if constexpr(GetKVWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
return 8;
|
||||
else
|
||||
return 4;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -121,12 +126,15 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
// Need special consideration for RowMajor since shuffling is needed to write LDS in dwords
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
|
||||
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
|
||||
@@ -137,13 +145,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
|
||||
return kVecLoad;
|
||||
}
|
||||
else
|
||||
else // Similar to GetAlignmentK()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
|
||||
return min(ElemPerThread, MaxVectorSize);
|
||||
}
|
||||
@@ -174,19 +177,38 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVSingleSmemElementSpaceSize()
|
||||
{
|
||||
using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
|
||||
|
||||
constexpr index_t Banks = 32; // TODO: need change based on arch
|
||||
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QKVDataType);
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
static_assert(PixelsPerRow % kKPack == 0);
|
||||
constexpr index_t NPerRow = PixelsPerRow / kKPack;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
static_assert(kNPerBlock % NPerRow == 0);
|
||||
static_assert(kKPerBlock % kKPack == 0);
|
||||
|
||||
return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
|
||||
// Need special consideration for RowMajor since shuffling is needed to write LDS in dwords
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t kKPack = GetKVWarpGemmKPerThreadSize<Problem>();
|
||||
|
||||
return N0 * (N1 * kKPerBlock + kKPack);
|
||||
}
|
||||
else // similar to GetKSingleSmemElementSpaceSize()
|
||||
{
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentV<Problem>();
|
||||
|
||||
if constexpr(GetKVWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
{
|
||||
static_assert(kKVector == kKPack);
|
||||
|
||||
return kKPerBlock * kNPerBlock + kKPerBlock;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(kKVector % kKPack == 0);
|
||||
|
||||
return kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector;
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
@@ -376,9 +398,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType);
|
||||
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
static_assert(0 < ElemPerThread);
|
||||
|
||||
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
|
||||
|
||||
constexpr index_t KPerThread = kMaxVecLoad;
|
||||
@@ -400,51 +421,136 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
|
||||
{
|
||||
using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
|
||||
|
||||
constexpr index_t NumVLdsBuffers = GetNumKVLdsBuffers<Problem>();
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
constexpr index_t Banks = 32; // TODO: need change based on arch
|
||||
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QKVDataType);
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
static_assert(PixelsPerRow % kKPack == 0);
|
||||
constexpr index_t NPerRow = PixelsPerRow / kKPack;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
static_assert(kNPerBlock % NPerRow == 0);
|
||||
static_assert(kKPerBlock % kKPack == 0);
|
||||
// Need special consideration for RowMajor since shuffling is needed to write LDS in dwords
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
|
||||
constexpr index_t VSingleSmemElementSpaceSize =
|
||||
(kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize<Problem>());
|
||||
// K2 is the vector size for storing shuffled tile to LDS
|
||||
constexpr index_t K2 = ElemPerThread / N1;
|
||||
|
||||
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
|
||||
// GetSmemKPackV() is the vector size for loading from LDS by BlockGemm
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
|
||||
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<NumVLdsBuffers>{},
|
||||
number<kKPerBlock / kKPack>{},
|
||||
number<kNPerBlock / NPerRow>{},
|
||||
number<NPerRow>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<SingleSmemElementSpaceSize>{},
|
||||
number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
|
||||
number<PixelsPerRow + kKPack>{},
|
||||
number<kKPack>{},
|
||||
number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
static_assert(kKPack >= K2, "Check failed!");
|
||||
|
||||
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
|
||||
v_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(
|
||||
number<NumVLdsBuffers>{}, number<kNPerBlock / NPerRow>{}, number<NPerRow>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
constexpr index_t VSingleSmemElementSpaceSize = N0 * (N1 * kKPerBlock + kKPack);
|
||||
|
||||
return v_lds_block_desc;
|
||||
static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize<Problem>());
|
||||
|
||||
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
|
||||
|
||||
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(
|
||||
number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{}, number<kKPerBlock>{}),
|
||||
make_tuple(number<SingleSmemElementSpaceSize>{},
|
||||
number<N1 * kKPerBlock + kKPack>{},
|
||||
number<kKPerBlock>{},
|
||||
number<1>{}),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
|
||||
v_lds_block_desc_0,
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{})),
|
||||
make_pass_through_transform(number<kKPerBlock>{})),
|
||||
make_tuple(sequence<0, 1, 2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return v_lds_block_desc;
|
||||
}
|
||||
else // Similar to MakeKLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentV<Problem>();
|
||||
|
||||
if constexpr(GetKVWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
{
|
||||
static_assert(kKVector == kKPack);
|
||||
|
||||
constexpr index_t VSingleSmemElementSpaceSize =
|
||||
kKPerBlock * kNPerBlock + kKPerBlock;
|
||||
|
||||
static_assert(VSingleSmemElementSpaceSize ==
|
||||
GetVSingleSmemElementSpaceSize<Problem>());
|
||||
|
||||
constexpr index_t SingleSmemElementSpaceSize =
|
||||
GetSingleSmemElementSpaceSize<Problem>();
|
||||
|
||||
constexpr auto v_lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<NumVLdsBuffers>{},
|
||||
number<kKPerBlock / kKPack>{},
|
||||
number<kNPerBlock>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<SingleSmemElementSpaceSize>{},
|
||||
number<kNPerBlock * kKPack + kKPack>{},
|
||||
number<kKPack>{},
|
||||
number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
|
||||
v_lds_block_desc_0,
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(number<NumVLdsBuffers>{}, number<kNPerBlock>{})),
|
||||
make_merge_transform(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return v_lds_block_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(kKVector % kKPack == 0);
|
||||
|
||||
constexpr index_t VSingleSmemElementSpaceSize =
|
||||
kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector;
|
||||
|
||||
static_assert(VSingleSmemElementSpaceSize ==
|
||||
GetVSingleSmemElementSpaceSize<Problem>());
|
||||
|
||||
constexpr index_t SingleSmemElementSpaceSize =
|
||||
GetSingleSmemElementSpaceSize<Problem>();
|
||||
|
||||
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<NumVLdsBuffers>{},
|
||||
number<kKPerBlock / kKVector>{},
|
||||
number<kKVector / kKPack>{},
|
||||
number<kNPerBlock>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<SingleSmemElementSpaceSize>{},
|
||||
number<kNPerBlock * kKVector + kKPack>{},
|
||||
number<kNPerBlock * kKPack>{},
|
||||
number<kKPack>{},
|
||||
number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
|
||||
v_lds_block_desc_0,
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(number<NumVLdsBuffers>{}, number<kNPerBlock>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKVector>{},
|
||||
number<kKVector / kKPack>{},
|
||||
number<kKPack>{}))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return v_lds_block_desc;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -456,66 +562,85 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
// Need special consideration for RowMajor since shuffling is needed to write LDS in dwords
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1; // P
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
static_assert(ElemPerThread % N1 == 0);
|
||||
constexpr index_t K3 = ElemPerThread / N1;
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t K2 = kKPack / K3;
|
||||
if constexpr(get_warp_size() % (K2 * N0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * N0);
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
static_assert(kKPerBlock == K0 * K1 * K2 * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = (K2 * N0) / get_warp_size();
|
||||
constexpr index_t K2_m = K2 / K1;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
|
||||
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
|
||||
tuple<sequence<2, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
static_assert(N0 != 0);
|
||||
|
||||
constexpr index_t K2 = ElemPerThread / N1;
|
||||
constexpr index_t K1 = get_warp_size() / N0;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<2, 1>>{});
|
||||
}
|
||||
else // Similar to MakeKDramTileDistribution()
|
||||
{
|
||||
using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType);
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
|
||||
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
|
||||
|
||||
constexpr index_t KPerThread = kMaxVecLoad;
|
||||
constexpr index_t KThreads = kKPerBlock / KPerThread;
|
||||
constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NPerThread, NThreadPerWarp, NumWarps>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
tuple<sequence<2>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegTileDistribution()
|
||||
{
|
||||
// This tile distribution is only used when the V layout is RowMajor
|
||||
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
|
||||
static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
static_assert(ElemPerThread % N1 == 0);
|
||||
|
||||
constexpr index_t K2 = ElemPerThread / N1;
|
||||
constexpr index_t K1 = get_warp_size() / N0;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetQKBlockGemmSingleRepM()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user