Merge some updates for ck_tile headers (#3342)

* fix some issues from internal branch

* update cshuffle_epilogue

* update cshuffle_epilogue

* update cshuffle

* update warp_gemm
This commit is contained in:
joyeamd
2026-01-06 15:39:00 +08:00
committed by GitHub
parent 2b563ad048
commit b78563b3d3
14 changed files with 205 additions and 119 deletions

View File

@@ -845,10 +845,10 @@ struct UniversalGemmBasePolicy
template <typename Problem>
CK_TILE_DEVICE static constexpr index_t GetSmemSizeA()
{
constexpr index_t smem_size_a =
integer_least_multiple(sizeof(typename Problem::ADataType) *
Problem::BlockGemmShape::kM * Problem::BlockGemmShape::kK,
16);
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr auto a_lds_block_desc = Derived::template MakeALdsBlockDescriptor<Problem>();
constexpr index_t smem_size_a = integer_least_multiple(
a_lds_block_desc.get_element_space_size() * sizeof(ADataType), 16);
return smem_size_a;
}
@@ -859,8 +859,9 @@ struct UniversalGemmBasePolicy
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
typename Problem::ADataType,
typename Problem::BDataType>;
constexpr index_t smem_size_b = integer_least_multiple(
sizeof(BDataType) * Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK, 16);
constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor<Problem>();
constexpr index_t smem_size_b = integer_least_multiple(
b_lds_block_desc.get_element_space_size() * sizeof(BDataType), 16);
return smem_size_b;
}

View File

@@ -53,11 +53,11 @@ struct TileGemmUniversalTraits
static constexpr int _VectorSize = VectorSize_;
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
using AsLayout = AsLayout_;
using BsLayout = BsLayout_;
using CLayout = CLayout_;
using AsLayout = AsLayout_;
using BsLayout = BsLayout_;
using CLayout = CLayout_;
static constexpr bool TransposeC = TransposeC_;
static constexpr bool TransposeC = TransposeC_;
static constexpr bool UseStructuredSparsity = UseStructuredSparsity_;
static constexpr bool UsePersistentKernel = UsePersistentKernel_;
static constexpr index_t NumWaveGroups = NumWaveGroups_;