Simplify the code that defines the KDram and QDram tile distributions

This commit is contained in:
Qianfeng Zhang
2025-12-14 13:50:49 +00:00
parent 1ab5e9da93
commit 125934a966

View File

@@ -173,9 +173,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
constexpr index_t kMaxVecLoad = Problem::GetVDramTileAccessMaxVectorSize();
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
// try to avoid writing sub-dword to LDS due to poor performance
constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad)
@@ -330,19 +329,13 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramSingleRepMTileDistribution()
{
using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType);
constexpr index_t kKVector = GetAlignmentQ<Problem>();
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t KPerThread = kMaxVecLoad;
constexpr index_t KPerThread = kKVector;
constexpr index_t KThreads = kKPerBlock / KPerThread;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
@@ -362,19 +355,13 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
{
using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::HstuAttentionTileSetting::kM0;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType);
constexpr index_t kKVector = GetAlignmentQ<Problem>();
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t KPerThread = kMaxVecLoad;
constexpr index_t KPerThread = kKVector;
constexpr index_t KThreads = kKPerBlock / KPerThread;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
@@ -511,18 +498,13 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
{
using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN0Sub;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType);
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
constexpr index_t kKVector = GetAlignmentK<Problem>();
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t KPerThread = kMaxVecLoad;
constexpr index_t KPerThread = kKVector;
constexpr index_t KThreads = kKPerBlock / KPerThread;
constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();