mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
Merge some updates for ck_tile headers (#3342)
* fix some issues from internal branch * update cshuffle_epilogue * update cshuffle_epilogue * update cshuffle * update warp_gemm
This commit is contained in:
@@ -845,10 +845,10 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
{
|
||||
constexpr index_t smem_size_a =
|
||||
integer_least_multiple(sizeof(typename Problem::ADataType) *
|
||||
Problem::BlockGemmShape::kM * Problem::BlockGemmShape::kK,
|
||||
16);
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
constexpr auto a_lds_block_desc = Derived::template MakeALdsBlockDescriptor<Problem>();
|
||||
constexpr index_t smem_size_a = integer_least_multiple(
|
||||
a_lds_block_desc.get_element_space_size() * sizeof(ADataType), 16);
|
||||
return smem_size_a;
|
||||
}
|
||||
|
||||
@@ -859,8 +859,9 @@ struct UniversalGemmBasePolicy
|
||||
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
constexpr index_t smem_size_b = integer_least_multiple(
|
||||
sizeof(BDataType) * Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK, 16);
|
||||
constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor<Problem>();
|
||||
constexpr index_t smem_size_b = integer_least_multiple(
|
||||
b_lds_block_desc.get_element_space_size() * sizeof(BDataType), 16);
|
||||
return smem_size_b;
|
||||
}
|
||||
|
||||
|
||||
@@ -53,11 +53,11 @@ struct TileGemmUniversalTraits
|
||||
static constexpr int _VectorSize = VectorSize_;
|
||||
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
|
||||
|
||||
using AsLayout = AsLayout_;
|
||||
using BsLayout = BsLayout_;
|
||||
using CLayout = CLayout_;
|
||||
using AsLayout = AsLayout_;
|
||||
using BsLayout = BsLayout_;
|
||||
using CLayout = CLayout_;
|
||||
static constexpr bool TransposeC = TransposeC_;
|
||||
|
||||
static constexpr bool TransposeC = TransposeC_;
|
||||
static constexpr bool UseStructuredSparsity = UseStructuredSparsity_;
|
||||
static constexpr bool UsePersistentKernel = UsePersistentKernel_;
|
||||
static constexpr index_t NumWaveGroups = NumWaveGroups_;
|
||||
|
||||
Reference in New Issue
Block a user