Improve the VDramTileDistribution and VLds layout for better device loading and reduced bank conflicts

This commit is contained in:
Qianfeng Zhang
2025-06-08 11:22:21 +00:00
parent 84eb9adc71
commit 4632d30cc0
2 changed files with 237 additions and 112 deletions

View File

@@ -480,7 +480,7 @@ struct HstuAttentionFwdPipelineQRKSVS
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<QKVDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
Policy::template MakeShuffledVRegTileDistribution<Problem>());
shuffle_tile(v_shuffle_tmp, v_tile);
// if K in this unroll uses Lds-buffer i, then V in this unroll uses Lds-buffer

View File

@@ -53,6 +53,16 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
return WG::WarpGemmAttribute::kKPerThread;
};
// Returns WarpGemmAttribute::kKPerThread of the warp-gemm selected by the policy of the
// KV block-gemm (the type returned by GetKVBlockGemm<Problem>()).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetKVWarpGemmKPerThreadSize()
{
    using KVBlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;

    // The type of element 0 of the returned config is used as the warp-gemm type.
    constexpr auto warp_gemm_config =
        KVBlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
    using WarpGemm = remove_cvref_t<decltype(warp_gemm_config.template at<0>())>;

    return WarpGemm::WarpGemmAttribute::kKPerThread;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
{
@@ -104,15 +114,10 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
// Returns the K-pack size used for V in shared memory.
//
// NOTE(review): this span is taken from a rendered diff hunk whose +/- markers are not
// visible. It presumably interleaves the pre-change body with the post-change body (the
// statements after the first `return` cannot be reached as written) -- confirm against
// the real file before relying on either variant.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
{
// Variant A (presumably the removed lines): the smaller of the elements handled per
// thread and the number of VDataType elements that fit in 16 bytes.
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
return min(ElemPerThread, MaxVectorSize);
// Variant B (presumably the added lines): 8 when the KV warp-gemm's kKPerThread is at
// least 8, otherwise 4.
if constexpr(GetKVWarpGemmKPerThreadSize<Problem>() >= 8)
return 8;
else
return 4;
}
template <typename Problem>
@@ -121,12 +126,15 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
// Need special consideration for RowMajor since shuffling is needed to write LDS in dwords
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
@@ -137,13 +145,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
return kVecLoad;
}
else
else // Similar to GetAlignmentK()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
return min(ElemPerThread, MaxVectorSize);
}
@@ -174,19 +177,38 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
// Returns the number of elements occupied by a single V buffer in shared memory,
// including the padding terms added in each formula below.
//
// NOTE(review): this span is taken from a rendered diff hunk whose +/- markers are not
// visible. It presumably interleaves the pre-change body (the Banks/PixelsPerRow based
// formula ending in the first `return`) with the post-change body (the VLayout-dependent
// branches) -- confirm against the real file.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVSingleSmemElementSpaceSize()
{
using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
constexpr index_t Banks = 32; // TODO: need change based on arch
// Banks * 4 bytes, expressed as a count of QKVDataType elements.
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QKVDataType);
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
static_assert(kNPerBlock % NPerRow == 0);
static_assert(kKPerBlock % kKPack == 0);
// Each row of PixelsPerRow elements is followed by kKPack extra elements.
return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
// Need special consideration for RowMajor since shuffling is needed to write LDS in dwords
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
// N0 groups, each holding N1 * kKPerBlock elements plus kKPack extra elements.
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t kKPack = GetKVWarpGemmKPerThreadSize<Problem>();
return N0 * (N1 * kKPerBlock + kKPack);
}
else // similar to GetKSingleSmemElementSpaceSize()
{
constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr index_t kKVector = GetAlignmentV<Problem>();
if constexpr(GetKVWarpGemmKPerThreadSize<Problem>() >= 8)
{
// Load vector size and K-pack must coincide in this case.
static_assert(kKVector == kKPack);
return kKPerBlock * kNPerBlock + kKPerBlock;
}
else
{
// Load vector size must be a whole multiple of the K-pack.
static_assert(kKVector % kKPack == 0);
return kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector;
};
};
};
template <typename Problem>
@@ -376,9 +398,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType);
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t KPerThread = kMaxVecLoad;
@@ -400,51 +421,136 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
// Builds the tensor descriptor for V in shared memory across NumVLdsBuffers buffers.
// Every variant below merges the buffer index with the N dimensions into output
// dimension 0 and exposes the K dimension(s) as output dimension 1.
//
// NOTE(review): this span is taken from a rendered diff hunk whose +/- markers are not
// visible. It presumably interleaves pre-change lines (the Banks/PixelsPerRow/NPerRow
// based descriptor) with post-change lines (the VLayout-dependent branches); several
// names are declared twice and statements from both versions alternate. Confirm against
// the real file before editing.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
{
using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
constexpr index_t NumVLdsBuffers = GetNumKVLdsBuffers<Problem>();
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QKVDataType);
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
static_assert(kNPerBlock % NPerRow == 0);
static_assert(kKPerBlock % kKPack == 0);
// Need special consideration for RowMajor since shuffling is needed to write LDS in dwords
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t VSingleSmemElementSpaceSize =
(kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize<Problem>());
// K2 is the vector size for storing shuffled tile to LDS
constexpr index_t K2 = ElemPerThread / N1;
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
// GetSmemKPackV() is the vector size for loading from LDS by BlockGemm
constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumVLdsBuffers>{},
number<kKPerBlock / kKPack>{},
number<kNPerBlock / NPerRow>{},
number<NPerRow>{},
number<kKPack>{}),
make_tuple(number<SingleSmemElementSpaceSize>{},
number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
number<PixelsPerRow + kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
static_assert(kKPack >= K2, "Check failed!");
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(
make_merge_transform(make_tuple(
number<NumVLdsBuffers>{}, number<kNPerBlock / NPerRow>{}, number<NPerRow>{})),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr index_t VSingleSmemElementSpaceSize = N0 * (N1 * kKPerBlock + kKPack);
return v_lds_block_desc;
static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize<Problem>());
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
// Dimensions (buffer, N0, N1, K); the N0 stride carries kKPack extra elements per group.
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(
number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{}, number<kKPerBlock>{}),
make_tuple(number<SingleSmemElementSpaceSize>{},
number<N1 * kKPerBlock + kKPack>{},
number<kKPerBlock>{},
number<1>{}),
number<8>{},
number<1>{});
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(make_merge_transform(
make_tuple(number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{})),
make_pass_through_transform(number<kKPerBlock>{})),
make_tuple(sequence<0, 1, 2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return v_lds_block_desc;
}
else // Similar to MakeKLdsBlockDescriptor()
{
constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr index_t kKVector = GetAlignmentV<Problem>();
if constexpr(GetKVWarpGemmKPerThreadSize<Problem>() >= 8)
{
static_assert(kKVector == kKPack);
constexpr index_t VSingleSmemElementSpaceSize =
kKPerBlock * kNPerBlock + kKPerBlock;
// Keep the size formula in lock-step with GetVSingleSmemElementSpaceSize().
static_assert(VSingleSmemElementSpaceSize ==
GetVSingleSmemElementSpaceSize<Problem>());
constexpr index_t SingleSmemElementSpaceSize =
GetSingleSmemElementSpaceSize<Problem>();
// Dimensions (buffer, K / kKPack, N, kKPack); each K-slab stride carries kKPack
// extra elements.
constexpr auto v_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumVLdsBuffers>{},
number<kKPerBlock / kKPack>{},
number<kNPerBlock>{},
number<kKPack>{}),
make_tuple(number<SingleSmemElementSpaceSize>{},
number<kNPerBlock * kKPack + kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(make_merge_transform(
make_tuple(number<NumVLdsBuffers>{}, number<kNPerBlock>{})),
make_merge_transform(
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0, 2>{}, sequence<1, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return v_lds_block_desc;
}
else
{
static_assert(kKVector % kKPack == 0);
constexpr index_t VSingleSmemElementSpaceSize =
kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector;
// Keep the size formula in lock-step with GetVSingleSmemElementSpaceSize().
static_assert(VSingleSmemElementSpaceSize ==
GetVSingleSmemElementSpaceSize<Problem>());
constexpr index_t SingleSmemElementSpaceSize =
GetSingleSmemElementSpaceSize<Problem>();
// Dimensions (buffer, K / kKVector, kKVector / kKPack, N, kKPack); each kKVector-wide
// slab stride carries kKPack extra elements.
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumVLdsBuffers>{},
number<kKPerBlock / kKVector>{},
number<kKVector / kKPack>{},
number<kNPerBlock>{},
number<kKPack>{}),
make_tuple(number<SingleSmemElementSpaceSize>{},
number<kNPerBlock * kKVector + kKPack>{},
number<kNPerBlock * kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(make_merge_transform(
make_tuple(number<NumVLdsBuffers>{}, number<kNPerBlock>{})),
make_merge_transform(make_tuple(number<kKPerBlock / kKVector>{},
number<kKVector / kKPack>{},
number<kKPack>{}))),
make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return v_lds_block_desc;
};
}
}
template <typename Problem>
@@ -456,66 +562,85 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
// Need special consideration for RowMajor since shuffling is needed to write LDS in dwords
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1; // P
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(ElemPerThread % N1 == 0);
constexpr index_t K3 = ElemPerThread / N1;
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3;
if constexpr(get_warp_size() % (K2 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
static_assert(kKPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
}
else
{
constexpr index_t K1 = GetAlignmentV<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = kNPerBlock / (N2 * N1);
static_assert(N0 != 0);
constexpr index_t K2 = ElemPerThread / N1;
constexpr index_t K1 = get_warp_size() / N0;
constexpr index_t K0 = kBlockSize / get_warp_size();
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 1>,
sequence<2, 1>>{});
}
else // Similar to MakeKDramTileDistribution()
{
using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;
constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType);
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t KPerThread = kMaxVecLoad;
constexpr index_t KThreads = kKPerBlock / KPerThread;
constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NPerThread, NThreadPerWarp, NumWarps>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
tuple<sequence<2>, sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
}
// Builds the static tile distribution for the shuffled V tile.
// Only meaningful when the V layout is RowMajor; any other layout is rejected at
// compile time by the static_assert below.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegTileDistribution()
{
    using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
    static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);

    constexpr index_t kBlockSize = Problem::kBlockSize;
    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;

    // N is split as (NOuter, NInner), with NInner taken from the V alignment.
    constexpr index_t NInner = GetAlignmentV<Problem>();
    constexpr index_t NOuter = kNPerBlock / NInner;

    // Elements of the N x K tile handled by each thread of the block.
    constexpr index_t ElemsPerThread = kNPerBlock * kKPerBlock / kBlockSize;
    static_assert(ElemsPerThread % NInner == 0);

    // K is split as (NumWarps, KThreadPerWarp, KPerThread).
    constexpr index_t KPerThread     = ElemsPerThread / NInner;
    constexpr index_t KThreadPerWarp = get_warp_size() / NOuter;
    constexpr index_t NumWarps       = kBlockSize / get_warp_size();

    using Encoding =
        tile_distribution_encoding<sequence<1>,
                                   tuple<sequence<NOuter, NInner>,
                                         sequence<NumWarps, KThreadPerWarp, KPerThread>>,
                                   tuple<sequence<2>, sequence<2, 1>>,
                                   tuple<sequence<0>, sequence<1, 0>>,
                                   sequence<1, 2>,
                                   sequence<1, 2>>;

    return make_static_tile_distribution(Encoding{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetQKBlockGemmSingleRepM()
{