Update the K-LDS layout to handle both WarpGemm-32x32x16 and WarpGemm-16x16x16

This commit is contained in:
Qianfeng Zhang
2025-04-24 15:02:57 +00:00
parent cea919aefb
commit a41371f734

View File

@@ -51,11 +51,23 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
Problem::BlockFmhaShape::kQKHeaddim>();
}
// Returns the kKPerThread attribute of the warp-level GEMM used by the block GEMM
// that GetQKBlockGemm<Problem>() produces.
//
// Everything is resolved at compile time: the block GEMM's Policy is asked for its
// warp-GEMM configuration, and the entry at index 0 of that result is taken as the
// warp GEMM type whose WarpGemmAttribute::kKPerThread is reported.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetQKWarpGemmKPerThreadSize()
{
    using QKBlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;

    // Only the type of element 0 is needed; the value itself is never read.
    constexpr auto warp_gemm_setup =
        QKBlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
    using QKWarpGemm = remove_cvref_t<decltype(warp_gemm_setup.template at<0>())>;

    return QKWarpGemm::WarpGemmAttribute::kKPerThread;
}
// Selects the K-pack width (kKPack) used when laying out the K tile in shared memory.
//
// The width follows the per-thread K extent of the Q*K warp GEMM reported by
// GetQKWarpGemmKPerThreadSize<Problem>():
//   - kKPerThread >= 8  -> pack 8 elements
//   - otherwise         -> pack 4 elements
//
// The previous data-type based formula (8 / sizeof(QKVDataType)) is intentionally
// gone: leaving it in front of the selection below made the selection unreachable
// and gave the deduced `auto` return type two conflicting candidates.
// NOTE(review): presumably the two cases correspond to WarpGemm-32x32x16 and
// WarpGemm-16x16x16 respectively - confirm against the warp GEMM definitions.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
{
    if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
        return 8;
    else
        return 4;
}
template <typename Problem>
@@ -66,9 +78,18 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr index_t kKVector = GetAlignmentK<Problem>();
static_assert(kKVector % kKPack == 0);
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
{
static_assert(kKVector == kKPack);
return kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector;
return kKPerBlock * kNPerBlock + kKPerBlock;
}
else
{
static_assert(kKVector % kKPack == 0);
return kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector;
};
};
template <typename Problem>
@@ -105,40 +126,76 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr index_t kKVector = GetAlignmentK<Problem>();
static_assert(kKVector % kKPack == 0);
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
{
static_assert(kKVector == kKPack);
constexpr index_t KSingleSmemElementSpaceSize =
kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector;
constexpr index_t KSingleSmemElementSpaceSize = kKPerBlock * kNPerBlock + kKPerBlock;
static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize<Problem>());
static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize<Problem>());
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
constexpr auto k_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumKLdsBuffers>{},
number<kKPerBlock / kKVector>{},
number<kKVector / kKPack>{},
number<kNPerBlock>{},
number<kKPack>{}),
make_tuple(number<SingleSmemElementSpaceSize>{},
number<kNPerBlock * kKVector + kKPack>{},
number<kNPerBlock * kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto k_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumKLdsBuffers>{},
number<kKPerBlock / kKPack>{},
number<kNPerBlock>{},
number<kKPack>{}),
make_tuple(number<SingleSmemElementSpaceSize>{},
number<kNPerBlock * kKPack + kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(
make_merge_transform(make_tuple(number<NumKLdsBuffers>{}, number<kNPerBlock>{})),
make_merge_transform(make_tuple(number<kKPerBlock / kKVector>{},
number<kKVector / kKPack>{},
number<kKPack>{}))),
make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(make_merge_transform(
make_tuple(number<NumKLdsBuffers>{}, number<kNPerBlock>{})),
make_merge_transform(
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0, 2>{}, sequence<1, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return k_lds_block_desc;
return k_lds_block_desc;
}
else
{
static_assert(kKVector % kKPack == 0);
constexpr index_t KSingleSmemElementSpaceSize =
kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector;
static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize<Problem>());
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
constexpr auto k_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumKLdsBuffers>{},
number<kKPerBlock / kKVector>{},
number<kKVector / kKPack>{},
number<kNPerBlock>{},
number<kKPack>{}),
make_tuple(number<SingleSmemElementSpaceSize>{},
number<kNPerBlock * kKVector + kKPack>{},
number<kNPerBlock * kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(make_merge_transform(
make_tuple(number<NumKLdsBuffers>{}, number<kNPerBlock>{})),
make_merge_transform(make_tuple(number<kKPerBlock / kKVector>{},
number<kKVector / kKPack>{},
number<kKPack>{}))),
make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return k_lds_block_desc;
};
}
template <typename Problem>